@@ -33,19 +33,19 @@ template <typename ScaleConfig, typename DTypeIn, typename DTypeSF, typename DTy
3333 typename ProblemShape, typename StrideA, typename StrideB, typename StrideC,
3434 typename LayoutSFA, typename LayoutSFB>
3535__global__ void compute_sm100_cutlass_group_gemm_args (
36- DTypeIn* A, DTypeIn* B, DTypeSF* SFA, DTypeSF* SFB, DTypeOut* C, int * m_indptr, int cum_m ,
37- int n , int k , int batch_size , int scale_granularity_m, int scale_granularity_n ,
38- int scale_granularity_k, ProblemShape* problem_sizes, const DTypeIn** A_ptr,
39- const DTypeIn ** B_ptr , const DTypeSF** SFA_ptr , const DTypeSF ** SFB_ptr, const DTypeOut** C_ptr ,
40- DTypeOut** D_ptr, StrideA* stride_A, StrideB* stride_B, StrideC* stride_C,
41- LayoutSFA* layout_SFA, LayoutSFB* layout_SFB) {
36+ DTypeIn* A, DTypeIn* B, DTypeSF* SFA, DTypeSF* SFB, DTypeOut* C, int * m_indptr, int n, int k ,
37+ int batch_size , int scale_granularity_m , int scale_granularity_n , int scale_granularity_k ,
38+ ProblemShape* problem_sizes, const DTypeIn** A_ptr, const DTypeIn** B_ptr ,
39+ const DTypeSF ** SFA_ptr , const DTypeSF** SFB_ptr , const DTypeOut ** C_ptr, DTypeOut** D_ptr ,
40+ StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, LayoutSFA* layout_SFA ,
41+ LayoutSFB* layout_SFB) {
4242 int i = blockIdx .x ;
4343 int m = m_indptr[i + 1 ] - m_indptr[i];
4444 problem_sizes[i] = ProblemShape (m, n, k);
4545 stride_A[i] = cutlass::make_cute_packed_stride (StrideA{}, {m, k, 1 });
4646 stride_B[i] = cutlass::make_cute_packed_stride (StrideB{}, {n, k, 1 });
4747 stride_C[i] = cutlass::make_cute_packed_stride (StrideC{}, {m, n, 1 });
48- layout_SFA[i] = ScaleConfig::tile_atom_to_shape_SFA (make_shape (cum_m , n, k, 1 ));
48+ layout_SFA[i] = ScaleConfig::tile_atom_to_shape_SFA (make_shape (m_indptr[batch_size] , n, k, 1 ));
4949 layout_SFB[i] = ScaleConfig::tile_atom_to_shape_SFB (make_shape (m, n, k, 1 ));
5050 A_ptr[i] = A + m_indptr[i] * k;
5151 B_ptr[i] = B + i * k * n;
@@ -61,8 +61,8 @@ cudaError_t CutlassGroupwiseScaledGroupGEMMSM100(void* int_buffer, size_t int_bu
6161 void * float_buffer,
6262 size_t float_buffer_size_in_bytes, DTypeIn* A,
6363 DTypeIn* B, float * SFA, float * SFB, DTypeOut* C,
64- int * m_indptr, int cum_m , int n , int k ,
65- int batch_size, cudaStream_t stream) {
64+ int * m_indptr, int n , int k , int batch_size ,
65+ cudaStream_t stream) {
6666 using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int , int , int >>; // <M,N,K> per group
6767
6868 using ElementA = DTypeIn; // Element type for A matrix operand
@@ -159,7 +159,7 @@ cudaError_t CutlassGroupwiseScaledGroupGEMMSM100(void* int_buffer, size_t int_bu
159159 " sm100_groupwise_group_gemm_layout_SFB" );
160160
161161 compute_sm100_cutlass_group_gemm_args<ScaleConfig><<<batch_size, 1 , 0 , stream>>> (
162- A, B, SFA, SFB, C, m_indptr, cum_m, n, k, batch_size, ScaleGranularityM, ScaleGranularityN,
162+ A, B, SFA, SFB, C, m_indptr, n, k, batch_size, ScaleGranularityM, ScaleGranularityN,
163163 ScaleGranularityK, problem_sizes, A_ptr, B_ptr, SFA_ptr, SFB_ptr, C_ptr, D_ptr, stride_A,
164164 stride_B, stride_C, layout_SFA, layout_SFB);
165165
0 commit comments