44
55import pytest
66
7- from vllm .model_executor .layers .pooler import CLSPool , PoolingType
7+ from vllm .model_executor .layers .pooler import CLSPool , MeanPool , PoolingType
88from vllm .model_executor .models .bert import BertEmbeddingModel
99from vllm .model_executor .models .roberta import RobertaEmbeddingModel
1010from vllm .platforms import current_platform
1414REVISION = os .environ .get ("REVISION" , "main" )
1515
1616MODEL_NAME_ROBERTA = os .environ .get ("MODEL_NAME" ,
17- "intfloat/multilingual-e5-small " )
17+ "intfloat/multilingual-e5-base " )
1818REVISION_ROBERTA = os .environ .get ("REVISION" , "main" )
1919
2020
@@ -40,17 +40,15 @@ def test_model_loading_with_params(vllm_runner):
4040
4141 # asserts on the pooling config files
4242 assert model_config .pooler_config .pooling_type == PoolingType .CLS .name
43- assert model_config .pooler_config .pooling_norm
43+ assert model_config .pooler_config .normalize
4444
4545 # asserts on the tokenizer loaded
4646 assert model_tokenizer .tokenizer_id == "BAAI/bge-base-en-v1.5"
47- assert model_tokenizer .tokenizer_config ["do_lower_case" ]
4847 assert model_tokenizer .tokenizer .model_max_length == 512
4948
5049 def check_model (model ):
5150 assert isinstance (model , BertEmbeddingModel )
52- assert model ._pooler .pooling_type == PoolingType .CLS
53- assert model ._pooler .normalize
51+ assert isinstance (model ._pooler , CLSPool )
5452
5553 vllm_model .apply_model (check_model )
5654
@@ -80,16 +78,15 @@ def test_roberta_model_loading_with_params(vllm_runner):
8078
8179 # asserts on the pooling config files
8280 assert model_config .pooler_config .pooling_type == PoolingType .MEAN .name
83- assert model_config .pooler_config .pooling_norm
81+ assert model_config .pooler_config .normalize
8482
8583 # asserts on the tokenizer loaded
86- assert model_tokenizer .tokenizer_id == "intfloat/multilingual-e5-small "
87- assert not model_tokenizer .tokenizer_config [ "do_lower_case" ]
84+ assert model_tokenizer .tokenizer_id == "intfloat/multilingual-e5-base "
85+ assert model_tokenizer .tokenizer . model_max_length == 512
8886
8987 def check_model (model ):
9088 assert isinstance (model , RobertaEmbeddingModel )
91- assert model ._pooler .pooling_type == PoolingType .MEAN
92- assert model ._pooler .normalize
89+ assert isinstance (model ._pooler , MeanPool )
9390
9491 vllm_model .apply_model (check_model )
9592
0 commit comments