         (4224, 4224),
     ],
 )
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize(
-    "cp_world_size", [4, 2, 1],  # 1 means disabling cp
+    "cp_world_size,cp_rank,cp_tot_seqlen_k_offset",
+    [
+        (8, 0, 1),
+        (8, 7, 0),
+        (4, 3, 2),
+        (2, 0, 0),
+        (1, 0, 0),  # 1 means disabling cp
+    ],
 )
-#@pytest.mark.parametrize('seqlen_q,seqlen_k', [(1, 1)])
 def test_flash_attn_output(
     seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink,
-    cp_world_size,
+    cp_world_size, cp_rank, cp_tot_seqlen_k_offset
 ):
     if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
         pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
@@ -157,7 +164,8 @@ def test_flash_attn_output(
     s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # print("s_aux ", s_aux)
-    cp_rank = 0
+    cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset
+    cp_tot_seqlen_k = torch.full((batch_size,), cp_tot_seqlen_k, device=device, dtype=torch.int32)
     if test_sink:
         dv_vals = [d]
     for dv in dv_vals:
@@ -175,7 +183,7 @@ def test_flash_attn_output(
         else:
             qv_ref = None
         # Put window_size after QKV randn so that window_size changes from test to test
-        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k * cp_world_size, (2,))
+        window_size = (-1, -1) if not local else torch.randint(0, cp_tot_seqlen_k[0], (2,))
         # window_size = (-1, -1) if not local else (16, 0)
         if dtype == torch.float8_e4m3fn:
             q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
@@ -199,6 +207,7 @@ def test_flash_attn_output(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )
         out_pt, attn_pt = attention_ref(
             q_ref,
@@ -217,6 +226,7 @@ def test_flash_attn_output(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )

         # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
@@ -251,6 +261,7 @@ def test_flash_attn_output(
                 s_aux=s_aux,
                 cp_world_size=cp_world_size,
                 cp_rank=cp_rank,
+                cp_tot_seqused_k=cp_tot_seqlen_k,
             )
             print("Pack GQA =", pack_gqa)
             print("Num splits =", num_splits)
@@ -378,7 +389,7 @@ def test_flash_attn_output(
     ],
 )
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink
+    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink,
 ):
     if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn):
         pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")
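For context on how the new test parameters relate: the test derives the global (context-parallel) K length as `cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset`, where `seqlen_k` is the per-rank K length. The sketch below only illustrates that arithmetic under the assumption of a contiguous per-rank K shard; the helper `cp_local_k_range` is hypothetical and is not part of this PR or the flash-attn API.

```python
import torch

def cp_local_k_range(cp_rank: int, local_seqlen_k: int):
    # Hypothetical helper: global [start, end) span of one rank's K shard,
    # assuming rank r holds the contiguous keys [r * local_seqlen_k, (r + 1) * local_seqlen_k).
    start = cp_rank * local_seqlen_k
    return start, start + local_seqlen_k

batch_size, seqlen_k = 2, 128  # per-rank K length, as in the test
for cp_world_size, cp_rank, cp_tot_seqlen_k_offset in [(8, 0, 1), (8, 7, 0), (4, 3, 2), (2, 0, 0), (1, 0, 0)]:
    # Same derivation as the test body: global K length = per-rank length * world size + offset,
    # where a non-zero offset exercises a total that is not an exact multiple of the shard size.
    cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset
    cp_tot_seqlen_k_tensor = torch.full((batch_size,), cp_tot_seqlen_k, dtype=torch.int32)
    start, end = cp_local_k_range(cp_rank, seqlen_k)
    assert end <= int(cp_tot_seqlen_k_tensor[0])  # each rank's shard fits inside the global K length
```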