File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -142,6 +142,7 @@ def flash_attn_varlen_func(
142142 q_descale = None ,
143143 k_descale = None ,
144144 v_descale = None ,
145+ num_splits : int = 0 ,
145146 # Version selector
146147 fa_version : int = DEFAULT_FA_VERSION ,
147148):
@@ -224,6 +225,8 @@ def flash_attn_varlen_func(
224225 "FA2 does not support scheduler_metadata, q_descale, "
225226 "k_descale, v_descale"
226227 )
228+ if num_splits > 1 :
229+ raise NotImplementedError ("FA2 does not support num_splits > 1" )
227230 out , softmax_lse = torch .ops ._vllm_fa2_C .varlen_fwd (
228231 q , k , v ,
229232 out ,
@@ -270,7 +273,7 @@ def flash_attn_varlen_func(
270273 softcap ,
271274 True , # rotary_interleaved
272275 scheduler_metadata ,
273- 0 , # num_splits
276+ num_splits ,
274277 None , # pack_gqa
275278 0 , # sm_margin
276279 )
You can’t perform that action at this time.
0 commit comments