File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -258,14 +258,21 @@ def __init__(self, config: ModernBertConfig):
258258 super ().__init__ ()
259259 self .dense = nn .Linear (config .hidden_size , config .hidden_size ,
260260 config .classifier_bias )
261+ self .pooling_type = config .classifier_pooling
261262 self .act = nn .GELU ()
262263 self .norm = nn .LayerNorm (config .hidden_size ,
263264 eps = config .norm_eps ,
264265 bias = config .norm_bias )
265266
266267 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
267268 pooled_output = hidden_states
268- pooled_output = pooled_output .mean (dim = 0 , keepdim = False )
269+ if self .pooling_type == "mean" :
270+ pooled_output = pooled_output .mean (dim = 0 , keepdim = False )
271+ elif self .pooling_type == "cls" :
272+ pooled_output = pooled_output [0 , :]
273+ else :
274+ raise ValueError ("Pooling type should be either `cls` or `mean`, "
275+ f"but got { self .pooling_type } " )
269276 pooled_output = self .norm (self .act (self .dense (pooled_output )))
270277 return pooled_output
271278
You can’t perform that action at this time.
0 commit comments