1 file changed: +4 -0 lines changed
Columns: original file line number | diff line number | diff line change
@@ -315,6 +315,7 @@ def flash_attn_with_kvcache(
315315    v_descale = None ,
316316    # Version selector 
317317    fa_version : int  =  DEFAULT_FA_VERSION ,
318+     s_aux = None ,
318319):
319320    """ 
320321    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from 
@@ -422,6 +423,8 @@ def flash_attn_with_kvcache(
422423                    "FA2 does not support scheduler_metadata, q_descale, " 
423424                    "k_descale, v_descale" 
424425                )
426+         if  s_aux  is  not   None :
427+             raise  NotImplementedError ("FA2 does not support s_aux" )
425428        out , softmax_lse  =  torch .ops ._vllm_fa2_C .fwd_kvcache (
426429            q , k_cache , v_cache ,
427430            k , v ,             # k_new, v_new 
@@ -466,6 +469,7 @@ def flash_attn_with_kvcache(
466469            num_splits ,          # num_splits 
467470            None ,                # pack_gqa 
468471            0 ,                   # sm_margin 
472+             s_aux ,               # s_aux 
469473        )
470474    else :
471475        raise  ValueError (f"Unsupported FA version: { fa_version }  " )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments