Skip to content

Commit

Permalink
Fix illegal memory access in the GEMV kernel
Browse files (browse the repository at this point in the history)
  • Loading branch information
周鹤云 committed Jun 12, 2024
1 parent e25b350 commit 7cb66b0
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -247,6 +247,7 @@ torch::Tensor gemv_forward_cuda_new(
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);

constexpr int kSmemByteSizePerBatch = N_PER_BLOCK * K_INTERLEAVE * BLOCK_SIZE;
// if (group_size == 64)
// {
// gemv_kernel_g64<<<num_blocks, num_threads>>>(
Expand All @@ -261,37 +262,37 @@ torch::Tensor gemv_forward_cuda_new(
switch (m)
{
case 1:
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 1>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 2:
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 2>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 3:
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 3>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 4:
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 4>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 5:
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 5>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 6:
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 6>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 7:
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 7>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
Expand Down

0 comments on commit 7cb66b0

Please sign in to comment.