Commit b5769e8

fix tests/models/embedding/vision_language/test_phi3v.py
Signed-off-by: pansicheng <sicheng.pan.chn@gmail.com>
1 parent: f3f8d8f

File tree (4 files changed: +110 additions, -14 deletions)

tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision_embedding.py
tests/models/embedding/vision_language/test_phi3v.py
vllm/model_executor/models/phi3v.py

tests/entrypoints/openai/test_vision.py

Lines changed: 40 additions & 6 deletions
@@ -3,6 +3,9 @@
 import openai
 import pytest
 import pytest_asyncio
+import requests
+from PIL import Image
+from transformers import AutoProcessor

 from vllm.multimodal.utils import encode_image_base64, fetch_image

@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]:
     }


+def get_hf_prompt_tokens(model_name, content, image_url):
+    processor = AutoProcessor.from_pretrained(model_name,
+                                              trust_remote_code=True,
+                                              num_crops=4)
+
+    placeholder = "<|image_1|>\n"
+    messages = [{
+        "role": "user",
+        "content": f"{placeholder}{content}",
+    }]
+    images = [Image.open(requests.get(image_url, stream=True).raw)]
+
+    prompt = processor.tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(prompt, images, return_tensors="pt")
+
+    return inputs.input_ids.shape[1]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_single_chat_session_image(client: openai.AsyncOpenAI,
                                          model_name: str, image_url: str):
+    content_text = "What's in this image?"
     messages = [{
         "role":
         "user",
@@ -70,25 +93,30 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
             },
             {
                 "type": "text",
-                "text": "What's in this image?"
+                "text": content_text
             },
         ],
     }]

+    max_completion_tokens = 10
     # test single completion
     chat_completion = await client.chat.completions.create(
         model=model_name,
         messages=messages,
-        max_completion_tokens=10,
+        max_completion_tokens=max_completion_tokens,
         logprobs=True,
         temperature=0.0,
         top_logprobs=5)
     assert len(chat_completion.choices) == 1

     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=774, total_tokens=784)
+        completion_tokens=max_completion_tokens,
+        prompt_tokens=hf_prompt_tokens,
+        total_tokens=hf_prompt_tokens + max_completion_tokens)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded(
         client: openai.AsyncOpenAI, model_name: str, image_url: str,
         base64_encoded_image: dict[str, str]):

+    content_text = "What's in this image?"
     messages = [{
         "role":
         "user",
@@ -163,25 +192,30 @@ async def test_single_chat_session_image_base64encoded(
             },
             {
                 "type": "text",
-                "text": "What's in this image?"
+                "text": content_text
             },
         ],
     }]

+    max_completion_tokens = 10
     # test single completion
     chat_completion = await client.chat.completions.create(
         model=model_name,
         messages=messages,
-        max_completion_tokens=10,
+        max_completion_tokens=max_completion_tokens,
         logprobs=True,
         temperature=0.0,
         top_logprobs=5)
     assert len(chat_completion.choices) == 1

     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=774, total_tokens=784)
+        completion_tokens=max_completion_tokens,
+        prompt_tokens=hf_prompt_tokens,
+        total_tokens=hf_prompt_tokens + max_completion_tokens)

     message = choice.message
     message = chat_completion.choices[0].message

tests/entrypoints/openai/test_vision_embedding.py

Lines changed: 21 additions & 3 deletions
@@ -2,6 +2,8 @@

 import pytest
 import requests
+from PIL import Image
+from transformers import AutoProcessor

 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.multimodal.utils import encode_image_base64, fetch_image
@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]:
     }


+def get_hf_prompt_tokens(model_name, content, image_url):
+    processor = AutoProcessor.from_pretrained(model_name,
+                                              trust_remote_code=True,
+                                              num_crops=4)
+
+    placeholder = "<|image_1|> "
+    prompt = f"{placeholder}{content}"
+    images = [Image.open(requests.get(image_url, stream=True).raw)]
+    inputs = processor(prompt, images, return_tensors="pt")
+    return inputs.input_ids.shape[1]
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
                                image_url: str):
+    content_text = "Represent the given image."
     messages = [{
         "role":
         "user",
@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
             },
             {
                 "type": "text",
-                "text": "Represent the given image."
+                "text": content_text
             },
         ],
     }]
@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
     response.raise_for_status()
     embeddings = EmbeddingResponse.model_validate(response.json())

+    hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
+                                            image_url)
+
     assert embeddings.id is not None
     assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 3072
     assert embeddings.usage.completion_tokens == 0
-    assert embeddings.usage.prompt_tokens == 763
-    assert embeddings.usage.total_tokens == 763
+    assert embeddings.usage.prompt_tokens == hf_prompt_tokens
+    assert embeddings.usage.total_tokens == hf_prompt_tokens

tests/models/embedding/vision_language/test_phi3v.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,10 @@

 import pytest
 import torch.nn.functional as F
+from PIL import Image
+
+from vllm.assets.base import get_vllm_public_assets
+from vllm.assets.image import VLM_IMAGES_DIR

 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
@@ -112,6 +116,15 @@ def test_models_image(
         (text, asset.pil_image)
         for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
     ]
+    # add cases for special_tokens
+    input_texts_images.append((
+        "\n<s><|user|>\n <|image_1|>\n\t <s>"
+        "Represent the given image for classification<|end|>"
+        "\n<|assistant|>\n",
+        Image.open(
+            get_vllm_public_assets(filename="cherry_blossom.jpg",
+                                   s3_prefix=VLM_IMAGES_DIR)),
+    ))
     input_texts = [text for text, _ in input_texts_images]
     input_images = [image for _, image in input_texts_images]

vllm/model_executor/models/phi3v.py

Lines changed: 36 additions & 5 deletions
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
@@ -428,10 +429,6 @@ def _get_prompt_updates(
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

-        tokenizer = self.info.get_tokenizer()
-        bos_token_id = tokenizer.bos_token_id
-        assert isinstance(bos_token_id, int)
-
         def get_replacement_phi3v(item_idx: int):
             images = mm_items.get_items(
                 "image", (ImageEmbeddingItems, ImageProcessorItems))
@@ -449,7 +446,7 @@ def get_replacement_phi3v(item_idx: int):
             image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens

             return PromptUpdateDetails(
-                full=image_tokens + [bos_token_id],
+                full=image_tokens,
                 features=image_tokens,
             )

@@ -469,6 +466,40 @@ def _apply_prompt_updates(
         mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
         mm_item_counts: Mapping[str, int],
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+        # align to hf behavior when there are images
+        if len(mm_item_counts):
+            tokenizer = self.info.get_tokenizer()
+            # to decode token_ids to the original text, we need to
+            # 1. remove the first bos token
+            # 2. remove space after each special token
+            #    introduced by the tokenizer
+            if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
+                token_ids = token_ids[1:]
+            text = tokenizer.decode(token_ids)
+            for special_tokens in tokenizer.special_tokens_map.values():
+                if isinstance(special_tokens, str):
+                    text = text.replace(f"{special_tokens} ", special_tokens)
+                elif isinstance(special_tokens, list):
+                    for special_token in special_tokens:
+                        text = text.replace(f"{special_token} ",
+                                            special_token)
+            # perform hf behavior
+            # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
+            pattern = r"<\|image_\d+\|>"
+            prompt_chunks = [
+                tokenizer(chunk).input_ids
+                for chunk in re.split(pattern, text)
+            ]
+            image_tags = [
+                tokenizer(chunk, add_special_tokens=False).input_ids
+                for chunk in re.findall(pattern, text)
+            ]
+            if len(prompt_chunks) > len(image_tags):
+                image_tags.append([])
+            token_ids = [
+                e for sublist in zip(prompt_chunks, image_tags)
+                for ele in sublist for e in ele
+            ]
+
         token_ids, text, placeholders = super()._apply_prompt_updates(
             token_ids=token_ids,
             mm_prompt_updates=mm_prompt_updates,
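For intuition, here is a minimal standalone sketch of the split-and-interleave behavior that the new _apply_prompt_updates code reproduces from the HF Phi-3-vision processor: the prompt text is split around <|image_N|> tags, each piece is tokenized separately, and the token lists are interleaved back together. The function name interleave_image_tags and the toy whitespace encode callable below are illustrative stand-ins, not vLLM or Hugging Face APIs.

import re


def interleave_image_tags(text, encode):
    # Split the prompt around <|image_N|> tags, tokenize each piece with the
    # supplied encode() callable, then interleave chunk tokens and tag tokens.
    pattern = r"<\|image_\d+\|>"
    prompt_chunks = [encode(chunk) for chunk in re.split(pattern, text)]
    image_tags = [encode(tag) for tag in re.findall(pattern, text)]
    if len(prompt_chunks) > len(image_tags):
        image_tags.append([])  # no tag follows the final text chunk
    return [
        tok for pair in zip(prompt_chunks, image_tags)
        for piece in pair for tok in piece
    ]


def toy_encode(s):
    # Toy "tokenizer": one token per whitespace-separated word.
    return s.split()


print(interleave_image_tags("a b <|image_1|> c d", toy_encode))
# -> ['a', 'b', '<|image_1|>', 'c', 'd']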
