Skip to content

Commit

Permalink
Revert "Revert "ignoring group size 32 from the compilation path""
Browse files Browse the repository at this point in the history
This reverts commit 89b66e4.
  • Loading branch information
HanGuo97 committed Dec 30, 2024
1 parent 89b66e4 commit 85adde6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
3 changes: 0 additions & 3 deletions flute/csrc/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ qgemm_raw_simple(const at::Tensor& input,
do { \
switch (group_size) \
{ \
- case 32: \
- RUN_QGEMM_RAW(T, NUM_BITS, 32); \
- break; \
case 64: \
RUN_QGEMM_RAW(T, NUM_BITS, 64); \
break; \
Expand Down
12 changes: 6 additions & 6 deletions flute/csrc/qgemm_kernel_raw_generated.cu
Original file line number Diff line number Diff line change
Expand Up @@ -798,28 +798,28 @@ _qgemm_raw(int M,
const cudaStream_t stream)


- INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 2, 32);
+ // INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 2, 32);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 2, 64);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 2, 128);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 2, 256);
- INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 3, 32);
+ // INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 3, 32);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 3, 64);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 3, 128);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 3, 256);
- INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 4, 32);
+ // INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 4, 32);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 4, 64);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 4, 128);
INSTANTIATE_TEMPLATE(cute::half_t , cute::uint16_t, __half2 , 4, 256);

- INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 2, 32);
+ // INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 2, 32);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 2, 64);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 2, 128);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 2, 256);
- INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 3, 32);
+ // INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 3, 32);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 3, 64);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 3, 128);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 3, 256);
- INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 32);
+ // INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 32);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 64);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 128);
INSTANTIATE_TEMPLATE(cute::bfloat16_t, cute::uint16_t, __nv_bfloat162, 4, 256);

0 comments on commit 85adde6

Please sign in to comment.