
Commit e00e8ce

bugfix: move cum_m calculation inside kernels (#1060)
1 parent a9935ea commit e00e8ce
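
This commit removes the `cum_m` argument that the Python wrapper used to compute on the host as `m_indptr[-1]` (and `padded_m_indptr[-1]`) and pass down through the pybind and dispatch layers; the SM100 grouped-GEMM argument kernel now reads the total row count directly from `m_indptr[batch_size]` on the device. The following is a minimal, self-contained sketch of that pattern, not the flashinfer kernel itself; names such as `compute_args_sketch`, `per_group_m`, and `total_m` are illustrative only.

// Sketch (hypothetical): derive both the per-group row count and the total row
// count from m_indptr on the device, so the host never passes a separate cum_m.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void compute_args_sketch(const int* m_indptr, int batch_size,
                                    int* per_group_m, int* total_m) {
  int i = blockIdx.x;                              // one block per group
  per_group_m[i] = m_indptr[i + 1] - m_indptr[i];  // rows in group i
  if (i == 0) *total_m = m_indptr[batch_size];     // total rows, read on device
}

int main() {
  constexpr int batch_size = 3;
  int h_indptr[batch_size + 1] = {0, 3, 7, 12};
  int *d_indptr, *d_per_group_m, *d_total_m;
  cudaMalloc(&d_indptr, sizeof(h_indptr));
  cudaMalloc(&d_per_group_m, batch_size * sizeof(int));
  cudaMalloc(&d_total_m, sizeof(int));
  cudaMemcpy(d_indptr, h_indptr, sizeof(h_indptr), cudaMemcpyHostToDevice);
  compute_args_sketch<<<batch_size, 1>>>(d_indptr, batch_size, d_per_group_m, d_total_m);
  int h_total_m = 0;
  cudaMemcpy(&h_total_m, d_total_m, sizeof(int), cudaMemcpyDeviceToHost);
  printf("total_m = %d\n", h_total_m);  // prints 12 for the example indptr above
  cudaFree(d_indptr);
  cudaFree(d_per_group_m);
  cudaFree(d_total_m);
  return 0;
}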

File tree

4 files changed: +17 −20 lines

csrc/group_gemm_groupwise_sm100.cu

Lines changed: 4 additions & 5 deletions
@@ -51,9 +51,8 @@ using namespace flashinfer;
 void CutlassGroupGemmGroupwiseScaledSM100(at::Tensor int_workspace_buffer,
                                           at::Tensor float_workspace_buffer, at::Tensor A,
                                           at::Tensor B, at::Tensor SFA, at::Tensor SFB,
-                                          at::Tensor C, at::Tensor m_indptr, int64_t cum_m,
-                                          int64_t n, int64_t k, int64_t scale_granularity_m,
-                                          int64_t scale_granularity_n,
+                                          at::Tensor C, at::Tensor m_indptr, int64_t n, int64_t k,
+                                          int64_t scale_granularity_m, int64_t scale_granularity_n,
                                           int64_t scale_granularity_k) {
   const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -72,8 +71,8 @@ void CutlassGroupGemmGroupwiseScaledSM100(at::Tensor int_workspace_buffer,
         float_workspace_buffer.element_size() * float_workspace_buffer.size(0),
         static_cast<cutlass_t_in*>(A.data_ptr()), static_cast<cutlass_t_in*>(B.data_ptr()),
         static_cast<float*>(SFA.data_ptr()), static_cast<float*>(SFB.data_ptr()),
-        static_cast<cutlass_t_out*>(C.data_ptr()), static_cast<int*>(m_indptr.data_ptr()),
-        cum_m, n, k, batch_size, stream);
+        static_cast<cutlass_t_out*>(C.data_ptr()), static_cast<int*>(m_indptr.data_ptr()), n,
+        k, batch_size, stream);
     return true;
   });
 });

csrc/group_gemm_sm100_pybind.cu

Lines changed: 3 additions & 3 deletions
@@ -18,9 +18,9 @@
 void CutlassGroupGemmGroupwiseScaledSM100(at::Tensor int_workspace_buffer,
                                           at::Tensor float_workspace_buffer, at::Tensor A,
                                           at::Tensor B, at::Tensor SFA, at::Tensor SFB,
-                                          at::Tensor C, at::Tensor m_indptr, int64_t cum_m,
-                                          int64_t n, int64_t k, int64_t scale_granularity_m,
-                                          int64_t scale_granularity_n, int64_t scale_granularity_k);
+                                          at::Tensor C, at::Tensor m_indptr, int64_t n, int64_t k,
+                                          int64_t scale_granularity_m, int64_t scale_granularity_n,
+                                          int64_t scale_granularity_k);

 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("group_gemm_fp8_nt_groupwise", CutlassGroupGemmGroupwiseScaledSM100);

flashinfer/gemm.py

Lines changed: 0 additions & 2 deletions
@@ -871,7 +871,6 @@ def group_gemm_fp8_nt_groupwise(
         b_scale,
         out,
         m_indptr,
-        m_indptr[-1],
         n,
         k,
         *scale_granularity_mnk,
@@ -914,7 +913,6 @@ def group_gemm_fp8_nt_groupwise(
         b_scale,
         padded_out,
         padded_m_indptr,
-        padded_m_indptr[-1],
         n,
         k,
         *scale_granularity_mnk,

include/flashinfer/gemm/group_gemm_groupwise_sm100.cuh

Lines changed: 10 additions & 10 deletions
@@ -33,19 +33,19 @@ template <typename ScaleConfig, typename DTypeIn, typename DTypeSF, typename DTypeOut,
           typename ProblemShape, typename StrideA, typename StrideB, typename StrideC,
           typename LayoutSFA, typename LayoutSFB>
 __global__ void compute_sm100_cutlass_group_gemm_args(
-    DTypeIn* A, DTypeIn* B, DTypeSF* SFA, DTypeSF* SFB, DTypeOut* C, int* m_indptr, int cum_m,
-    int n, int k, int batch_size, int scale_granularity_m, int scale_granularity_n,
-    int scale_granularity_k, ProblemShape* problem_sizes, const DTypeIn** A_ptr,
-    const DTypeIn** B_ptr, const DTypeSF** SFA_ptr, const DTypeSF** SFB_ptr, const DTypeOut** C_ptr,
-    DTypeOut** D_ptr, StrideA* stride_A, StrideB* stride_B, StrideC* stride_C,
-    LayoutSFA* layout_SFA, LayoutSFB* layout_SFB) {
+    DTypeIn* A, DTypeIn* B, DTypeSF* SFA, DTypeSF* SFB, DTypeOut* C, int* m_indptr, int n, int k,
+    int batch_size, int scale_granularity_m, int scale_granularity_n, int scale_granularity_k,
+    ProblemShape* problem_sizes, const DTypeIn** A_ptr, const DTypeIn** B_ptr,
+    const DTypeSF** SFA_ptr, const DTypeSF** SFB_ptr, const DTypeOut** C_ptr, DTypeOut** D_ptr,
+    StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, LayoutSFA* layout_SFA,
+    LayoutSFB* layout_SFB) {
   int i = blockIdx.x;
   int m = m_indptr[i + 1] - m_indptr[i];
   problem_sizes[i] = ProblemShape(m, n, k);
   stride_A[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
   stride_B[i] = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
   stride_C[i] = cutlass::make_cute_packed_stride(StrideC{}, {m, n, 1});
-  layout_SFA[i] = ScaleConfig::tile_atom_to_shape_SFA(make_shape(cum_m, n, k, 1));
+  layout_SFA[i] = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m_indptr[batch_size], n, k, 1));
   layout_SFB[i] = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
   A_ptr[i] = A + m_indptr[i] * k;
   B_ptr[i] = B + i * k * n;
@@ -61,8 +61,8 @@ cudaError_t CutlassGroupwiseScaledGroupGEMMSM100(void* int_buffer, size_t int_buffer_size_in_bytes,
                                                  void* float_buffer,
                                                  size_t float_buffer_size_in_bytes, DTypeIn* A,
                                                  DTypeIn* B, float* SFA, float* SFB, DTypeOut* C,
-                                                 int* m_indptr, int cum_m, int n, int k,
-                                                 int batch_size, cudaStream_t stream) {
+                                                 int* m_indptr, int n, int k, int batch_size,
+                                                 cudaStream_t stream) {
   using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group

   using ElementA = DTypeIn;  // Element type for A matrix operand
@@ -159,7 +159,7 @@ cudaError_t CutlassGroupwiseScaledGroupGEMMSM100(void* int_buffer, size_t int_buffer_size_in_bytes,
       "sm100_groupwise_group_gemm_layout_SFB");

   compute_sm100_cutlass_group_gemm_args<ScaleConfig><<<batch_size, 1, 0, stream>>>(
-      A, B, SFA, SFB, C, m_indptr, cum_m, n, k, batch_size, ScaleGranularityM, ScaleGranularityN,
+      A, B, SFA, SFB, C, m_indptr, n, k, batch_size, ScaleGranularityM, ScaleGranularityN,
       ScaleGranularityK, problem_sizes, A_ptr, B_ptr, SFA_ptr, SFB_ptr, C_ptr, D_ptr, stride_A,
       stride_B, stride_C, layout_SFA, layout_SFB);
