@@ -117,14 +117,52 @@ def correct_attn_out(
117117 if ctx is None :
118118 ctx = CPTritonContext ()
119119
120- lse = torch .empty_like (lses [0 ])
121-
122- grid = (out .shape [0 ], out .shape [1 ], 1 )
123- regular_args = (out , out , lses , lse , * out .stride (), * lses .stride (), cp_rank )
124- const_args = {
125- "HEAD_DIM" : out .shape [- 1 ],
126- "N_ROUNDED" : lses .shape [0 ],
127- }
120+ # --- Normalize to 3D views ---
121+ if out .ndim == 4 and out .shape [1 ] == 1 :
122+ out = out .squeeze (1 )
123+ assert out .ndim == 3 , f"expected out [B,H,D] or [B,1,H,D], got { tuple (out .shape )} "
124+
125+ if lses .ndim == 4 and lses .shape [- 1 ] == 1 :
126+ lses = lses .squeeze (- 1 )
127+ if lses .ndim == 4 and lses .shape [1 ] == 1 :
128+ lses = lses .squeeze (1 )
129+ assert lses .ndim == 3 , (
130+ f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
131+ f"got { tuple (lses .shape )} "
132+ )
133+
134+ B , H , D = out .shape
135+ N = lses .shape [0 ]
136+
137+ # Strides after we normalized shapes to 3-D views. The kernel computes
138+ # offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
139+ # have the same B/H stride layout as a slice of `lses`.
140+ o_sB , o_sH , o_sD = out .stride ()
141+ l_sN , l_sB , l_sH = lses .stride ()
142+
143+ # Allocate LSE with the same B/H strides as `lses` so writes land correctly
144+ # even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
145+ lse = torch .empty_strided (
146+ (B , H ), (l_sB , l_sH ), device = lses .device , dtype = lses .dtype
147+ )
148+
149+ # Kernel launch config
150+ grid = (B , H , 1 )
151+
152+ regular_args = (
153+ out ,
154+ out ,
155+ lses ,
156+ lse ,
157+ o_sB ,
158+ o_sH ,
159+ o_sD ,
160+ l_sN ,
161+ l_sB ,
162+ l_sH ,
163+ cp_rank ,
164+ )
165+ const_args = {"HEAD_DIM" : D , "N_ROUNDED" : N }
128166
129167 ctx .call_kernel (_correct_attn_cp_out_kernel , grid , * regular_args , ** const_args )
130168 return out , lse
0 commit comments