
Commit 2102075

NickLucche, mgoin, and lsy323 authored
[TPU][V1] Capture multimodal encoder during model compilation (#15051)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
1 parent 71eda0b commit 2102075
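
This commit makes the TPU V1 backend compile the multimodal (vision) encoder ahead of time, alongside the language model, so the first image request at serving time does not trigger a slow XLA recompilation; the new CI test below exercises that path end to end. As a rough illustration of the idea only (a minimal sketch, not vLLM's actual implementation: the warmup_encoder helper, the batch-size buckets, and the 336x336 image size are all assumptions), pre-compiling an encoder on TPU amounts to running dummy inputs through it once per expected shape:

# Hypothetical sketch of ahead-of-time encoder capture on TPU; names and
# shapes are illustrative, not vLLM internals.
import torch
import torch_xla.core.xla_model as xm


def warmup_encoder(encoder: torch.nn.Module,
                   image_size: int = 336,
                   batch_sizes: tuple[int, ...] = (1, 4, 8)) -> None:
    device = xm.xla_device()
    encoder = encoder.to(device).eval()
    for bs in batch_sizes:
        # One dummy forward pass per bucketed batch size: XLA traces the
        # encoder graph for this shape and compiles it now, instead of on
        # the first real request.
        dummy_pixels = torch.zeros(bs, 3, image_size, image_size,
                                   device=device)
        with torch.no_grad():
            encoder(dummy_pixels)
        xm.mark_step()  # flush the traced graph so compilation happens here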

File tree

4 files changed (+327, -48 lines)


.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,7 @@ source /etc/environment
 docker run --privileged --net host --shm-size=16G -it \
     -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
     vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
-    && python3 -m pip install pytest tpu-info \
+    && python3 -m pip install pytest pytest-asyncio tpu-info \
     && python3 -m pip install lm_eval[api]==0.4.4 \
     && export VLLM_USE_V1=1 \
     && export VLLM_XLA_CHECK_RECOMPILATION=1 \
@@ -42,6 +42,8 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_8 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
     && echo TEST_9 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
+    && echo TEST_10 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \

tests/v1/tpu/test_multimodal.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import openai
+import pytest
+
+from vllm import envs
+from vllm.multimodal.utils import encode_image_base64, fetch_image
+from vllm.platforms import current_platform
+
+from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS
+from ...utils import RemoteOpenAIServer
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+@pytest.fixture(scope="session")
+def base64_encoded_image() -> dict[str, str]:
+    return {
+        image_url: encode_image_base64(fetch_image(image_url))
+        for image_url in TEST_IMAGE_URLS
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
+async def test_basic_vision(model_name: str, base64_encoded_image: dict[str,
+                                                                         str]):
+
+    def whats_in_this_image_msg(b64):
+        return [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's in this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{b64}"
+                    },
+                },
+            ],
+        }]
+
+    server_args = [
+        "--max-model-len",
+        "1024",
+        "--max-num-seqs",
+        "16",
+        "--gpu-memory-utilization",
+        "0.95",
+        "--trust-remote-code",
+        "--max-num-batched-tokens",
+        "576",
+        # NOTE: max-num-batched-tokens>=mm_item_size
+        "--disable_chunked_mm_input",
+        "--chat-template",
+        "examples/template_llava.jinja"
+    ]
+
+    # Server will pre-compile on first startup (takes a long time).
+    with RemoteOpenAIServer(model_name, server_args,
+                            max_wait_seconds=600) as remote_server:
+        client: openai.AsyncOpenAI = remote_server.get_async_client()
+
+        # Other requests now should be much faster
+        for image_url in TEST_IMAGE_URLS:
+            image_base64 = base64_encoded_image[image_url]
+            chat_completion_from_base64 = await client.chat.completions\
+                .create(
+                    model=model_name,
+                    messages=whats_in_this_image_msg(image_base64),
+                    max_completion_tokens=24,
+                    temperature=0.0)
+            result = chat_completion_from_base64
+            assert result
+            choice = result.choices[0]
+            assert choice.finish_reason == "length"
+
+            message = choice.message
+            message = result.choices[0].message
+            assert message.content is not None and len(message.content) >= 10
+            assert message.role == "assistant"
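
For a quick manual check outside of pytest, a roughly equivalent request can be sent with the synchronous OpenAI client once a server is already running. This is a sketch under assumptions: the server is assumed to listen on the default port 8000, and the API key value and placeholder image URL must be adapted to your setup.

# Standalone sketch mirroring the request the test sends; assumes an
# OpenAI-compatible vLLM server for llava-hf/llava-1.5-7b-hf is already
# running locally (port 8000 and the API key value are assumptions).
import base64

import openai
import requests

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

image_url = "https://example.com/image.jpg"  # placeholder: supply a real image URL
b64 = base64.b64encode(requests.get(image_url).content).decode()

response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url",
             "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        ],
    }],
    max_completion_tokens=24,
    temperature=0.0,
)
print(response.choices[0].message.content)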
