Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput #8248

Merged
Merged 2 commits on Sep 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions vllm/model_executor/layers/quantization/awq_triton.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(

# Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
- offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
+ offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

masks_y = offsets_y < num_rows
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(

# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks)
+ iweights = tl.interleave(iweights, iweights)
+ iweights = tl.interleave(iweights, iweights)
+ iweights = tl.interleave(iweights, iweights)

# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights = (iweights >> shifts) & 0xF

# Compute zero offsets and masks.
- zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
- tl.arange(0, BLOCK_SIZE_Y) // group_size)
- zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
+ zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
+ zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

zero_masks_y = zero_offsets_y < num_rows // group_size
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(

# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF

# Compute scale offsets and masks.
- scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
- tl.arange(0, BLOCK_SIZE_Y) // group_size)
+ scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8))
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(

# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
+ scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Dequantize.
iweights = (iweights - zeros) * scales
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
masks_am = offsets_am < M

- offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
- tl.arange(0, BLOCK_SIZE_N) // 8)
+ offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_bn = offsets_bn < N // 8

- offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
- tl.arange(0, BLOCK_SIZE_N) // 8)
+ offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_zn = offsets_zn < N // 8

offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,

masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b)
+ b = tl.interleave(b, b)
+ b = tl.interleave(b, b)
+ b = tl.interleave(b, b)

# Dequantize b.
offsets_szk = (
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
- tl.arange(0, BLOCK_SIZE_K) // group_size)
+ tl.arange(0, 1))
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.interleave(zeros, zeros)
+ zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s)
+ scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

b = (b >> shifts) & 0xF
zeros = (zeros >> shifts) & 0xF
Expand Down
Loading