Skip to content

Commit

Permalink
[Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize…
Browse files Browse the repository at this point in the history
…, 2x throughput (vllm-project#8248)
  • Loading branch information
rasmith authored Sep 6, 2024
1 parent 1447c97 commit 9db52ea
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions vllm/model_executor/layers/quantization/awq_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def awq_dequantize_kernel(

# Compute offsets and masks for qweight_ptr.
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

masks_y = offsets_y < num_rows
Expand All @@ -43,6 +43,9 @@ def awq_dequantize_kernel(

# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)

# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
Expand All @@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights = (iweights >> shifts) & 0xF

# Compute zero offsets and masks.
zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
tl.arange(0, BLOCK_SIZE_Y) // group_size)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

zero_masks_y = zero_offsets_y < num_rows // group_size
Expand All @@ -70,13 +72,16 @@ def awq_dequantize_kernel(

# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros = (zeros >> shifts) & 0xF

# Compute scale offsets and masks.
scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
tl.arange(0, BLOCK_SIZE_Y) // group_size)
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
tl.arange(0, BLOCK_SIZE_X * 8))
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
Expand All @@ -87,6 +92,7 @@ def awq_dequantize_kernel(

# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Dequantize.
iweights = (iweights - zeros) * scales
Expand Down Expand Up @@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
masks_am = offsets_am < M

offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
tl.arange(0, BLOCK_SIZE_N) // 8)
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_bn = offsets_bn < N // 8

offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
tl.arange(0, BLOCK_SIZE_N) // 8)
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
masks_zn = offsets_zn < N // 8

offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
Expand All @@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,

masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)

# Dequantize b.
offsets_szk = (
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
tl.arange(0, BLOCK_SIZE_K) // group_size)
tl.arange(0, 1))
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

b = (b >> shifts) & 0xF
zeros = (zeros >> shifts) & 0xF
Expand Down

0 comments on commit 9db52ea

Please sign in to comment.