File tree Expand file tree Collapse file tree 2 files changed +24
-1
lines changed Expand file tree Collapse file tree 2 files changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -399,6 +399,9 @@ def as_reward_model(cls: _T) -> _T:
399399    # Lazy import 
400400    from  vllm .model_executor .layers .pooler  import  DispatchPooler , Pooler 
401401
402+     from  .interfaces_base  import  default_pooling_type 
403+ 
404+     @default_pooling_type ("ALL" ) 
402405    class  ModelForReward (_create_pooling_model_cls (cls )):
403406        def  _init_pooler (self , vllm_config : "VllmConfig" , prefix : str  =  "" ):
404407            pooler_config  =  vllm_config .model_config .pooler_config 
Original file line number Diff line number Diff line change @@ -3622,8 +3622,28 @@ def _dummy_pooler_run(
36223622        hidden_states : torch .Tensor ,
36233623    ) ->  PoolerOutput :
36243624        # Find the task that has the largest output for subsequent steps 
3625+         supported_pooling_tasks  =  self .get_supported_pooling_tasks ()
3626+ 
3627+         if  not  supported_pooling_tasks :
3628+             if  self .scheduler_config .chunked_prefill_enabled :
3629+                 raise  RuntimeError (
3630+                     f"Model { self .model_config .model }   does not support " 
3631+                     "any pooling tasks with chunked prefill enabled. " 
3632+                     "Please add --no-enable-chunked-prefill to your " 
3633+                     "config or CLI args. See " 
3634+                     "https://docs.vllm.ai/en/latest/models/pooling_models.html " 
3635+                     "to learn more." 
3636+                 )
3637+             else :
3638+                 raise  RuntimeError (
3639+                     f"Model { self .model_config .model }   does not support " 
3640+                     "any pooling tasks. See " 
3641+                     "https://docs.vllm.ai/en/latest/models/pooling_models.html " 
3642+                     "to learn more." 
3643+                 )
3644+ 
36253645        output_size  =  dict [PoolingTask , float ]()
3626-         for  task  in  self . get_supported_pooling_tasks () :
3646+         for  task  in  supported_pooling_tasks :
36273647            # Run a full batch with each task to ensure none of them OOMs 
36283648            output  =  self ._dummy_pooler_run_task (hidden_states , task )
36293649            output_size [task ] =  sum (o .nbytes  for  o  in  output )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments