46 changes: 40 additions & 6 deletions tests/entrypoints/openai/test_vision.py
@@ -3,6 +3,9 @@
import openai
import pytest
import pytest_asyncio
import requests
from PIL import Image
from transformers import AutoProcessor

from vllm.multimodal.utils import encode_image_base64, fetch_image

@@ -53,11 +56,31 @@ def base64_encoded_image() -> dict[str, str]:
}


def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)

placeholder = "<|image_1|>\n"
messages = [{
"role": "user",
"content": f"{placeholder}{content}",
}]
images = [Image.open(requests.get(image_url, stream=True).raw)]

prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, images, return_tensors="pt")

return inputs.input_ids.shape[1]


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image(client: openai.AsyncOpenAI,
model_name: str, image_url: str):
content_text = "What's in this image?"
messages = [{
"role":
"user",
@@ -70,25 +93,30 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
},
{
"type": "text",
"text": "What's in this image?"
"text": content_text
},
],
}]

max_completion_tokens = 10
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
max_completion_tokens=max_completion_tokens,
logprobs=True,
temperature=0.0,
top_logprobs=5)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784)
completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)

message = choice.message
message = chat_completion.choices[0].message
@@ -150,6 +178,7 @@ async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
base64_encoded_image: dict[str, str]):

content_text = "What's in this image?"
messages = [{
"role":
"user",
@@ -163,25 +192,30 @@ async def test_single_chat_session_image_base64encoded(
},
{
"type": "text",
"text": "What's in this image?"
"text": content_text
},
],
}]

max_completion_tokens = 10
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
max_completion_tokens=max_completion_tokens,
logprobs=True,
temperature=0.0,
top_logprobs=5)
assert len(chat_completion.choices) == 1

choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=774, total_tokens=784)
completion_tokens=max_completion_tokens,
prompt_tokens=hf_prompt_tokens,
total_tokens=hf_prompt_tokens + max_completion_tokens)

message = choice.message
message = chat_completion.choices[0].message
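Note: the expected prompt length asserted above is no longer a constant. It depends on the processor configuration (the num_crops=4 passed to AutoProcessor) and on the image itself, which is why the test now derives it from the HF processor instead of hardcoding 774/784 tokens. A minimal offline sketch of the same counting (the model name and the hand-written chat-template prompt are assumptions for illustration, not fixed by this diff):

import requests
from PIL import Image
from transformers import AutoProcessor


def count_prompt_tokens(image_url: str, num_crops: int = 4) -> int:
    # Assumed model; num_crops should mirror what the server is configured with.
    processor = AutoProcessor.from_pretrained("microsoft/Phi-3.5-vision-instruct",
                                               trust_remote_code=True,
                                               num_crops=num_crops)
    image = Image.open(requests.get(image_url, stream=True).raw)
    # Hand-written approximation of the chat-templated prompt built in get_hf_prompt_tokens.
    prompt = "<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n"
    return processor(prompt, [image], return_tensors="pt").input_ids.shape[1]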
24 changes: 21 additions & 3 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -2,6 +2,8 @@

import pytest
import requests
from PIL import Image
from transformers import AutoProcessor

from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.multimodal.utils import encode_image_base64, fetch_image
@@ -52,11 +54,24 @@ def base64_encoded_image() -> dict[str, str]:
}


def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True,
num_crops=4)

placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}"
images = [Image.open(requests.get(image_url, stream=True).raw)]
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
image_url: str):
content_text = "Represent the given image."
messages = [{
"role":
"user",
@@ -69,7 +84,7 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
},
{
"type": "text",
"text": "Represent the given image."
"text": content_text
},
],
}]
@@ -85,9 +100,12 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
response.raise_for_status()
embeddings = EmbeddingResponse.model_validate(response.json())

hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
image_url)

assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 3072
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 763
assert embeddings.usage.total_tokens == 763
assert embeddings.usage.prompt_tokens == hf_prompt_tokens
assert embeddings.usage.total_tokens == hf_prompt_tokens
13 changes: 13 additions & 0 deletions tests/models/embedding/vision_language/test_phi3v.py
@@ -2,6 +2,10 @@

import pytest
import torch.nn.functional as F
from PIL import Image

from vllm.assets.base import get_vllm_public_assets
from vllm.assets.image import VLM_IMAGES_DIR

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
@@ -112,6 +116,15 @@ def test_models_image(
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
# add cases for special_tokens
input_texts_images.append((
"\n<s><|user|>\n <|image_1|>\n\t <s>"
"Represent the given image for classification<|end|>"
"\n<|assistant|>\n",
Image.open(
get_vllm_public_assets(filename="cherry_blossom.jpg",
s3_prefix=VLM_IMAGES_DIR)),
))
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]

41 changes: 36 additions & 5 deletions vllm/model_executor/models/phi3v.py
@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
@@ -428,10 +429,6 @@ def _get_prompt_updates(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore

tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)

def get_replacement_phi3v(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
@@ -449,7 +446,7 @@ def get_replacement_phi3v(item_idx: int):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens

return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
full=image_tokens,
features=image_tokens,
)

@@ -469,6 +466,40 @@ def _apply_prompt_updates(
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images
if len(mm_item_counts):
tokenizer = self.info.get_tokenizer()
# to decode token_ids to the original text, we need to
# 1. remove the first bos token
# 2. remove space after each special token
# introduced by the tokenizer
if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
token_ids = token_ids[1:]
text = tokenizer.decode(token_ids)
for special_tokens in tokenizer.special_tokens_map.values():
if isinstance(special_tokens, str):
text = text.replace(f"{special_tokens} ", special_tokens)
elif isinstance(special_tokens, list):
for special_token in special_tokens:
text = text.replace(f"{special_token} ", special_token)
# perform hf behavior
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
pattern = r"<\|image_\d+\|>"
prompt_chunks = [
tokenizer(chunk).input_ids
for chunk in re.split(pattern, text)
]
image_tags = [
tokenizer(chunk, add_special_tokens=False).input_ids
for chunk in re.findall(pattern, text)
]
if len(prompt_chunks) > len(image_tags):
image_tags.append([])
token_ids = [
e for sublist in zip(prompt_chunks, image_tags)
for ele in sublist for e in ele
]

token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates,
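Note on the _apply_prompt_updates change above: after dropping the leading bos token and the spaces the tokenizer inserts after special tokens during decoding, the prompt text is re-split around the <|image_N|> tags and re-tokenized chunk by chunk, interleaving text chunks with image tags the same way HF's processing_phi3_v does. A toy sketch of that interleaving on plain strings (tokenization omitted; the prompt text is an illustrative assumption):

import re

text = "<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n"
pattern = r"<\|image_\d+\|>"

prompt_chunks = re.split(pattern, text)  # text around the image tags
image_tags = re.findall(pattern, text)   # the tags themselves, in order
if len(prompt_chunks) > len(image_tags):
    image_tags.append("")                # pad so zip() keeps the trailing chunk

# Interleave chunk, tag, chunk, tag, ...; in phi3v.py each piece is tokenized
# before being flattened into a single token_ids list.
interleaved = [piece for pair in zip(prompt_chunks, image_tags) for piece in pair]
print(interleaved)
# ['<|user|>\n', '<|image_1|>', "\nWhat's in this image?<|end|>\n<|assistant|>\n", '']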