Skip to content

Commit 0cd103e

Browse files
authored
CP: make correct_attn_out robust to 4‑D views and fix Triton arg binding (#26509)
Signed-off-by: Huamin Li <3ericli@gmail.com>
1 parent 5be7ca1 commit 0cd103e

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

vllm/attention/ops/common.py

Lines changed: 46 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -117,14 +117,52 @@ def correct_attn_out(
117117
if ctx is None:
118118
ctx = CPTritonContext()
119119

120-
lse = torch.empty_like(lses[0])
121-
122-
grid = (out.shape[0], out.shape[1], 1)
123-
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank)
124-
const_args = {
125-
"HEAD_DIM": out.shape[-1],
126-
"N_ROUNDED": lses.shape[0],
127-
}
120+
# --- Normalize to 3D views ---
121+
if out.ndim == 4 and out.shape[1] == 1:
122+
out = out.squeeze(1)
123+
assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
124+
125+
if lses.ndim == 4 and lses.shape[-1] == 1:
126+
lses = lses.squeeze(-1)
127+
if lses.ndim == 4 and lses.shape[1] == 1:
128+
lses = lses.squeeze(1)
129+
assert lses.ndim == 3, (
130+
f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
131+
f"got {tuple(lses.shape)}"
132+
)
133+
134+
B, H, D = out.shape
135+
N = lses.shape[0]
136+
137+
# Strides after we normalized shapes to 3-D views. The kernel computes
138+
# offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
139+
# have the same B/H stride layout as a slice of `lses`.
140+
o_sB, o_sH, o_sD = out.stride()
141+
l_sN, l_sB, l_sH = lses.stride()
142+
143+
# Allocate LSE with the same B/H strides as `lses` so writes land correctly
144+
# even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
145+
lse = torch.empty_strided(
146+
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
147+
)
148+
149+
# Kernel launch config
150+
grid = (B, H, 1)
151+
152+
regular_args = (
153+
out,
154+
out,
155+
lses,
156+
lse,
157+
o_sB,
158+
o_sH,
159+
o_sD,
160+
l_sN,
161+
l_sB,
162+
l_sH,
163+
cp_rank,
164+
)
165+
const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
128166

129167
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
130168
return out, lse

0 commit comments

Comments
 (0)