12 changes: 6 additions & 6 deletions csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
 
-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
     int4* sh_ptr = sh + stage_size * pipe;
 
     if (threadIdx.x < stage_size) {
-      int k_id = threadIdx.x / stage_n_threads;
-      int n_id = threadIdx.x % stage_n_threads;
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;
 
       int first_k = k_tile_id * tile_k_size;
 
@@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
     return;
   }
 
-  int warp_id = threadIdx.x / 32;
-  int th_id = threadIdx.x % 32;
+  auto warp_id = threadIdx.x / 32;
+  auto th_id = threadIdx.x % 32;
 
   if (warp_id >= 4) {
     return;
@@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
-}
+}
28 changes: 14 additions & 14 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
                                     int size_k, int lda, int block_rows) {
-  int start_row = block_rows * blockIdx.x;
+  auto start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
     finish_row = size_m;
@@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
     int base_k = 0;
 
     for (int i = 0; i < iters; i++) {
-      int cur_k = base_k + threadIdx.x;
+      auto cur_k = base_k + threadIdx.x;
       int src_pos = perm_int_ptr[cur_k];
 
       out_half[cur_k] = a_row_half[src_pos];
@@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
 
     if (rest) {
       if (threadIdx.x < rest) {
-        int cur_k = base_k + threadIdx.x;
+        auto cur_k = base_k + threadIdx.x;
         int src_pos = perm_int_ptr[cur_k];
 
         out_half[cur_k] = a_row_half[src_pos];
@@ -723,8 +723,8 @@ __global__ void Marlin(
               (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;
 
   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -743,7 +743,7 @@ __global__ void Marlin(
                  s_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
 
   // Zero-points
@@ -756,7 +756,7 @@ __global__ void Marlin(
                   zp_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int zp_sh_wr = threadIdx.x;
+  auto zp_sh_wr = threadIdx.x;
   bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
 
   // We use a different scale layout for grouped and column-wise quantization as
@@ -1047,7 +1047,7 @@ __global__ void Marlin(
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;
       reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
     } else {
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps = thread_n_blocks / 4;
 
       int warp_row = warp_id / n_warps;
@@ -1085,7 +1085,7 @@ __global__ void Marlin(
 
       // Determine "position" inside the thread-block (based on warp and
       // thread-id)
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps =
          thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N
 
@@ -1094,7 +1094,7 @@ __global__ void Marlin(
 
       cur_k += warp_row * 16;
 
-      int th_id = threadIdx.x % 32;
+      auto th_id = threadIdx.x % 32;
       cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix
 
       int s_col_shift =
@@ -1159,7 +1159,7 @@ __global__ void Marlin(
             (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
       }
     } else {
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps = thread_n_blocks / 4;
 
       int warp_row = warp_id / n_warps;
@@ -1197,7 +1197,7 @@ __global__ void Marlin(
                       (pipe / (group_blocks / thread_k_blocks)));
       reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
     } else {
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps = thread_n_blocks / 4;
 
       int warp_row = warp_id / n_warps;
@@ -1323,7 +1323,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
Expand Down Expand Up @@ -1390,7 +1390,7 @@ __global__ void Marlin(
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
auto c_sh_wr = threadIdx.x;

int row = (threadIdx.x % 32) / 4;

16 changes: 8 additions & 8 deletions csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
 
-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
 
     if constexpr (has_perm) {
       if (threadIdx.x < stage_size) {
-        int k_id = threadIdx.x / stage_n_threads;
-        int n_id = threadIdx.x % stage_n_threads;
+        auto k_id = threadIdx.x / stage_n_threads;
+        auto n_id = threadIdx.x % stage_n_threads;
 
         uint32_t const* sh_perm_int_ptr =
             reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
 
     } else {
       if (threadIdx.x < stage_size) {
-        int k_id = threadIdx.x / stage_n_threads;
-        int n_id = threadIdx.x % stage_n_threads;
+        auto k_id = threadIdx.x / stage_n_threads;
+        auto n_id = threadIdx.x % stage_n_threads;
 
         int first_k = k_tile_id * tile_k_size;
         int first_k_packed = first_k / pack_factor;
@@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
     return;
   }
 
-  int warp_id = threadIdx.x / 32;
-  int th_id = threadIdx.x % 32;
+  auto warp_id = threadIdx.x / 32;
+  auto th_id = threadIdx.x % 32;
 
   if (warp_id >= 4) {
     return;
@@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
 
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
-}
+}
10 changes: 5 additions & 5 deletions csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
@@ -277,12 +277,12 @@ __global__ void Marlin(
       b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  auto b_sh_wr = threadIdx.x;
+  auto b_sh_rd = threadIdx.x;
 
   int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
                 s_sh_stride * slice_col + threadIdx.x;
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   int s_sh_rd;
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
@@ -455,7 +455,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
+      auto red_idx = threadIdx.x / b_sh_stride;
       constexpr int red_sh_stride = b_sh_stride * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -522,7 +522,7 @@ __global__ void Marlin(
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;
 
     int row = (threadIdx.x % 32) / 4;
 
14 changes: 7 additions & 7 deletions csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
@@ -353,10 +353,10 @@ __global__ void Marlin(
       b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  auto b_sh_wr = threadIdx.x;
+  auto b_sh_rd = threadIdx.x;
 
-  int s_tok_gl_rd = threadIdx.x;
+  auto s_tok_gl_rd = threadIdx.x;
   // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
   // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
   // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
@@ -368,8 +368,8 @@ __global__ void Marlin(
   int s_tok_sh_rd = (threadIdx.x % 32) / 4;
   bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
 
-  int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
-  int s_ch_sh_wr = threadIdx.x;
+  auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
+  auto s_ch_sh_wr = threadIdx.x;
   int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                    2 * ((threadIdx.x % 32) % 4);
   bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
@@ -558,7 +558,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
+      auto red_idx = threadIdx.x / b_sh_stride;
       constexpr int red_sh_stride = b_sh_stride * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -628,7 +628,7 @@ __global__ void Marlin(
                    8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
     c_gl_wr += (4 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads * 2;
-    int c_sh_wr = 2 * threadIdx.x;
+    auto c_sh_wr = 2 * threadIdx.x;
 
     int row = (threadIdx.x % 32) / 4;
 
14 changes: 7 additions & 7 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -273,15 +273,15 @@ __global__ void Marlin_24(
               (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;
 
   int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
                 (threadIdx.x % (m_sh_stride));
   m_gl_rd += (m_sh_stride)*slice_col;
   m_gl_rd += m_gl_rd_delta_o * slice_row;
-  int m_sh_wr = threadIdx.x;
-  int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
+  auto m_sh_wr = threadIdx.x;
+  auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
 
   int s_gl_rd;
   if constexpr (group_blocks == -1) {
@@ -291,7 +291,7 @@ __global__ void Marlin_24(
               s_sh_stride * slice_col + threadIdx.x;
   }
 
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   int s_sh_rd;
   // We use a different scale layout for grouped and column-wise quantization as
   // we scale a `half2` tile in column-major layout in the former and in
@@ -516,7 +516,7 @@ __global__ void Marlin_24(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -583,7 +583,7 @@ __global__ void Marlin_24(
                    8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;
 
     int col = 2 * ((threadIdx.x % 32) % 4);
 
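Note on the pattern: every hunk in this PR makes the same mechanical change. CUDA's threadIdx and blockIdx built-ins are uint3, so expressions such as threadIdx.x / 32 or blockIdx.x * block_k_tiles already have type unsigned int; declaring the result as int forces an implicit unsigned-to-signed conversion that some compilers flag (e.g. under -Wsign-conversion), while auto lets the declared type match what the expression produces. Below is a minimal sketch of the deduction involved; the kernel and variable names are illustrative only, not from this PR:

// Hypothetical demo kernel, not part of the PR.
__global__ void index_type_demo(int block_rows, int* out) {
  int narrowed = blockIdx.x * block_rows;   // unsigned int implicitly converted to int
  auto deduced = blockIdx.x * block_rows;   // deduced as unsigned int, no conversion
  auto warp_id = threadIdx.x / 32;          // unsigned int
  auto lane_id = threadIdx.x % 32;          // unsigned int
  out[0] = narrowed;
  out[1] = static_cast<int>(deduced + warp_id + lane_id);  // narrow explicitly where intended
}

For the index magnitudes these kernels work with, both spellings compute the same values; the auto form only keeps the declared type consistent with the expression, so no behavior change is expected.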