File tree: 5 files changed, +13 -2 lines changed
  language/pooling_mteb_test
  vllm/model_executor/layers

@@ -191,7 +191,7 @@ def mteb_test_embed_models(
     with vllm_runner(
         model_info.name,
         runner="pooling",
-        max_model_len=None,
+        max_model_len=model_info.max_model_len,
         **vllm_extra_kwargs,
     ) as vllm_model:
         model_config = vllm_model.llm.llm_engine.model_config
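
The test wrapper above now forwards the per-model max_model_len instead of always passing None. For context, a minimal sketch of what that value means at the engine level, assuming vLLM's public LLM entry point; the model name is a placeholder:

from vllm import LLM

# Placeholder model name, for illustration only.
# max_model_len=None (the old behaviour) derives the context length from
# the model's own config; an explicit value such as 8192 caps it instead.
llm = LLM(model="some-org/some-embedding-model", max_model_len=8192)
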
@@ -25,6 +25,11 @@
         mteb_score=0.824413164,
         architecture="XLMRobertaModel",
         is_matryoshka=True,
+        # The default max length of the model is 8194, which will crash
+        # CUDAGraph due to the odd length for the Gemm. We set it to 8192
+        # to avoid this issue.
+        max_model_len=8192,
+        dtype="float32",
     )
 ]

@@ -23,6 +23,7 @@
         architecture="Gemma3TextModel",
         mteb_score=0.7473819294684156,
         enable_test=True,
+        dtype="float32",
     ),
 ]

@@ -369,6 +369,7 @@ class ModelInfo:
     name: str
     architecture: str = ""
     dtype: str = "auto"
+    max_model_len: Optional[int] = None
     hf_dtype: str = "float32"
     hf_overrides: Optional[dict[str, Any]] = None
     default_pooling_type: str = ""
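
A trimmed, self-contained sketch of the dataclass change above (field names taken from the diff, everything else illustrative): the new field defaults to None, so existing test entries are unaffected and only models that need a cap set it.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelInfo:  # trimmed to the fields relevant here
    name: str
    architecture: str = ""
    dtype: str = "auto"
    max_model_len: Optional[int] = None  # None -> keep the model's own default

# Only models that need an override set it; everything else is unchanged.
capped = ModelInfo("org/embedder", architecture="XLMRobertaModel", max_model_len=8192)
assert ModelInfo("org/other-model").max_model_len is None
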
@@ -318,7 +318,11 @@ def forward_static(
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         if residual is not None:
-            x = x + residual.float() if orig_dtype == torch.float16 else x + residual
+            x = (
+                x.float() + residual.float()
+                if orig_dtype == torch.float16
+                else x + residual
+            )
             residual = x

         x = x.float()
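
For reference, a standalone sketch of the pattern the new expression spells out, assuming a standard RMSNorm formulation (this is not vLLM's implementation; the names and eps value are illustrative): for float16 activations, both x and the residual are upcast so the residual add and the norm run in float32, and only the final result is cast back to the original dtype.

import torch

def rms_norm_f32_residual(x, weight, residual=None, eps=1e-6):
    # Reference sketch only, not the vLLM kernel or its Python fallback.
    orig_dtype = x.dtype
    if residual is not None:
        # Perform the residual add in float32 for float16 activations.
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        residual = x

    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = (x * weight.float()).to(orig_dtype)
    return x if residual is None else (x, residual)
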