Skip to content

Commit

Permalink
Fix out-of-bounds memory access in Galore dequant kernel (pytorch#1125)
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst authored Oct 21, 2024
1 parent a2faafe commit f33cff7
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torchao/prototype/galore/kernels/quant.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@ def _dequant_kernel(
  dq_ptr,
  stride_qm,
  stride_qn,
+ M,
+ N,
  GROUP_SIZE: tl.constexpr,
  BLOCK_M: tl.constexpr,
  BLOCK_N: tl.constexpr,
@@ -22,17 +24,18 @@ def _dequant_kernel(
  # rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
  # rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
  offsets = rm[:, None] * stride_qm + rn[None, :] * stride_qn
+ mask = (rm[:, None] < M) & (rn[None, :] < N)
  tl.static_print(offsets)
  group_offsets = offsets // GROUP_SIZE
  tl.static_print("group_offsets", group_offsets)
- q_idx = tl.load(q_idx_ptr + offsets)
+ q_idx = tl.load(q_idx_ptr + offsets, mask=mask)
  tl.static_print(q_idx)
  # NOTE: Must upcast q_idx to int32 (q_idx is tl.uint8, which does not work for pointer indexing)
  q_vals = tl.load(qmap_ptr + q_idx.to(tl.int32))
- absmax = tl.load(absmax_ptr + group_offsets)
+ absmax = tl.load(absmax_ptr + group_offsets, mask=group_offsets < (M * N // GROUP_SIZE))

  dq = q_vals * absmax
- tl.store(dq_ptr + offsets, dq)
+ tl.store(dq_ptr + offsets, dq, mask=mask)


def triton_dequant_blockwise(
@@ -51,6 +54,8 @@ def triton_dequant_blockwise(
  dq,
  q.stride(0),
  q.stride(1),
+ M,
+ N,
  BLOCK_M=1,
  BLOCK_N=group_size,
  GROUP_SIZE=group_size,
Expand Down

0 comments on commit f33cff7

Please sign in to comment.