33import pytest
44from transformers import AutoTokenizer
55
6- from vllm .config import VisionLanguageConfig
76from vllm .multimodal .utils import rescale_image_size
87from vllm .sequence import SampleLogprobs
98
2120 "USER: <image>\n What's in this image?\n ASSISTANT:" ,
2221})
2322
23+ IMAGE_TOKEN_ID = 32000
2424
25- def iter_llava_configs (model_name : str ):
26- image_hw_to_feature_size = {
27- (336 , 336 ): 576 ,
28- }
29-
30- for (h , w ), f in image_hw_to_feature_size .items ():
31- input_shape = (1 , 3 , h , w )
32- yield (model_name ,
33- VisionLanguageConfig (image_feature_size = f ,
34- image_token_id = 32000 ,
35- image_input_shape = input_shape ))
36-
37-
38- model_and_vl_config = [
39- * iter_llava_configs ("llava-hf/llava-1.5-7b-hf" ),
40- ]
25+ models = ["llava-hf/llava-1.5-7b-hf" ]
4126
4227
4328def vllm_to_hf_output (vllm_output : Tuple [List [int ], str ,
4429 Optional [SampleLogprobs ]],
45- vlm_config : VisionLanguageConfig , model_id : str ):
46- """Sanitize vllm output to be comparable with hf output.
47- The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
48- x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
49- It also reduces `output_str` from "<image><image>bla" to "bla".
50- """
30+ model : str ):
31+ """Sanitize vllm output to be comparable with hf output."""
5132 output_ids , output_str , out_logprobs = vllm_output
52- image_token_id = vlm_config .image_token_id
5333
54- tokenizer = AutoTokenizer .from_pretrained (model_id )
55- image_token_str = tokenizer .decode (image_token_id )
34+ tokenizer = AutoTokenizer .from_pretrained (model )
5635 eos_token_id = tokenizer .eos_token_id
5736
5837 hf_output_ids = [
5938 token_id for idx , token_id in enumerate (output_ids )
60- if token_id != image_token_id or output_ids [idx - 1 ] != image_token_id
39+ if token_id != IMAGE_TOKEN_ID or output_ids [idx - 1 ] != IMAGE_TOKEN_ID
6140 ]
6241
63- hf_output_str = output_str \
64- .replace (image_token_str * vlm_config .image_feature_size , "" )
65- assert hf_output_str [0 ] == " "
66- hf_output_str = hf_output_str [1 :]
42+ assert output_str [0 ] == " "
43+ hf_output_str = output_str [1 :]
6744 if hf_output_ids [- 1 ] == eos_token_id :
6845 hf_output_str = hf_output_str + tokenizer .decode (eos_token_id )
6946
@@ -74,7 +51,7 @@ def run_test(
7451 hf_runner : Type [HfRunner ],
7552 vllm_runner : Type [VllmRunner ],
7653 image_assets : _ImageAssets ,
77- model_and_config : Tuple [ str , VisionLanguageConfig ] ,
54+ model ,
7855 * ,
7956 size_factors : List [float ],
8057 dtype : str ,
@@ -92,7 +69,6 @@ def run_test(
9269 Note, the text input is also adjusted to abide by vllm contract.
9370 The text output is sanitized to be able to compare with hf.
9471 """
95- model_id , vlm_config = model_and_config
9672 images = [asset .pil_image for asset in image_assets ]
9773
9874 inputs_per_image = [(
@@ -106,12 +82,11 @@ def run_test(
10682 # will hurt multiprocessing backend with fork method (the default method).
10783
10884 # max_model_len should be greater than image_feature_size
109- with vllm_runner (model_id ,
85+ with vllm_runner (model ,
11086 dtype = dtype ,
11187 tensor_parallel_size = tensor_parallel_size ,
11288 distributed_executor_backend = distributed_executor_backend ,
113- enforce_eager = True ,
114- ** vlm_config .as_cli_args_dict ()) as vllm_model :
89+ enforce_eager = True ) as vllm_model :
11590 vllm_outputs_per_image = [
11691 vllm_model .generate_greedy_logprobs (prompts ,
11792 max_tokens ,
@@ -120,7 +95,7 @@ def run_test(
12095 for prompts , images in inputs_per_image
12196 ]
12297
123- with hf_runner (model_id , dtype = dtype , is_vision_model = True ) as hf_model :
98+ with hf_runner (model , dtype = dtype , is_vision_model = True ) as hf_model :
12499 hf_outputs_per_image = [
125100 hf_model .generate_greedy_logprobs_limit (prompts ,
126101 max_tokens ,
@@ -136,15 +111,15 @@ def run_test(
136111 check_logprobs_close (
137112 outputs_0_lst = hf_outputs ,
138113 outputs_1_lst = [
139- vllm_to_hf_output (vllm_output , vlm_config , model_id )
114+ vllm_to_hf_output (vllm_output , model )
140115 for vllm_output in vllm_outputs
141116 ],
142117 name_0 = "hf" ,
143118 name_1 = "vllm" ,
144119 )
145120
146121
147- @pytest .mark .parametrize ("model_and_config " , model_and_vl_config )
122+ @pytest .mark .parametrize ("model " , models )
148123@pytest .mark .parametrize (
149124 "size_factors" ,
150125 [
@@ -161,14 +136,13 @@ def run_test(
161136@pytest .mark .parametrize ("dtype" , ["half" ])
162137@pytest .mark .parametrize ("max_tokens" , [128 ])
163138@pytest .mark .parametrize ("num_logprobs" , [5 ])
164- def test_models (hf_runner , vllm_runner , image_assets , model_and_config ,
165- size_factors , dtype : str , max_tokens : int ,
166- num_logprobs : int ) -> None :
139+ def test_models (hf_runner , vllm_runner , image_assets , model , size_factors ,
140+ dtype : str , max_tokens : int , num_logprobs : int ) -> None :
167141 run_test (
168142 hf_runner ,
169143 vllm_runner ,
170144 image_assets ,
171- model_and_config ,
145+ model ,
172146 size_factors = size_factors ,
173147 dtype = dtype ,
174148 max_tokens = max_tokens ,
0 commit comments