1 file changed: +10 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
178178    page_size ,
179179    logit_cap ,
180180):
181-     BLOCK  =  64 
181+     BLOCK  =  64  if  not  is_hip_  else  8 
182+ 
182183    NUM_KV_SPLITS  =  num_kv_splits 
183184    Lk  =  k_buffer .shape [- 1 ]
184185    Lv  =  v_buffer .shape [- 1 ]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
188189    grid  =  (batch , head_num , NUM_KV_SPLITS )
189190    kv_group_num  =  q .shape [1 ] //  k_buffer .shape [- 2 ]
190191
191-     num_warps  =  4  if  kv_group_num  ==  1  else  2 
192+     num_warps  =  4 
193+     if  kv_group_num  !=  1 :
194+         num_warps  =  1  if  is_hip_  else  2 
192195
193196    BLOCK_DMODEL  =  triton .next_power_of_2 (Lk )
194197    BLOCK_DV  =  triton .next_power_of_2 (Lv )
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
418421    )
419422
420423    extra_kargs  =  {}
424+     num_stages  =  2 
421425    if  is_hip_ :
422-         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
426+         # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
423427        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py 
424428        extra_kargs  =  {
425-             "waves_per_eu" : 4 ,
429+             "waves_per_eu" : 1 ,
426430            "matrix_instr_nonkdim" : 16 ,
427431            "kpack" : 2 
428432        }
433+         num_stages  =  1 
429434
430435    _fwd_grouped_kernel_stage1 [grid ](
431436        q ,
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
456461        PAGE_SIZE = page_size ,
457462        logit_cap = logit_cap ,
458463        num_warps = 4 ,
459-         num_stages = 2 ,
464+         num_stages = num_stages ,
460465        Lk = Lk ,
461466        Lv = Lv ,
462467        ** extra_kargs ,
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments