|  | 
|  | 1 | +# SPDX-License-Identifier: Apache-2.0 | 
|  | 2 | + | 
|  | 3 | +import openai | 
|  | 4 | +import pytest | 
|  | 5 | + | 
|  | 6 | +from vllm import envs | 
|  | 7 | +from vllm.multimodal.utils import encode_image_base64, fetch_image | 
|  | 8 | +from vllm.platforms import current_platform | 
|  | 9 | + | 
|  | 10 | +from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS | 
|  | 11 | +from ...utils import RemoteOpenAIServer | 
|  | 12 | + | 
# Fail fast at collection time: everything in this module exercises the V1
# engine path, so skip the whole file unless VLLM_USE_V1 is enabled.
if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )
|  | 18 | + | 
|  | 19 | + | 
@pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]:
    """Map each test image URL to its base64-encoded payload.

    Session-scoped so every image is fetched and encoded exactly once
    per test run, no matter how many tests consume the fixture.
    """
    encoded: dict[str, str] = {}
    for url in TEST_IMAGE_URLS:
        encoded[url] = encode_image_base64(fetch_image(url))
    return encoded
|  | 26 | + | 
|  | 27 | + | 
@pytest.mark.asyncio
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
async def test_basic_vision(model_name: str, base64_encoded_image: dict[str,
                                                                        str]):
    """End-to-end multimodal smoke test against a TPU-backed server.

    Launches a vLLM OpenAI-compatible server for a LLaVA model, then for
    every test image sends a chat completion request carrying the image as
    an inline base64 data URL and checks that a non-trivial assistant
    response comes back.
    """

    def whats_in_this_image_msg(b64: str) -> list[dict]:
        """Build an OpenAI chat `messages` payload with one inline image."""
        return [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{b64}"
                    },
                },
            ],
        }]

    server_args = [
        "--max-model-len",
        "1024",
        "--max-num-seqs",
        "16",
        "--gpu-memory-utilization",
        "0.95",
        "--trust-remote-code",
        "--max-num-batched-tokens",
        "576",
        # NOTE: max-num-batched-tokens>=mm_item_size
        "--disable_chunked_mm_input",
        "--chat-template",
        "examples/template_llava.jinja"
    ]

    # Server will pre-compile on first startup (takes a long time).
    with RemoteOpenAIServer(model_name, server_args,
                            max_wait_seconds=600) as remote_server:
        client: openai.AsyncOpenAI = remote_server.get_async_client()

        # Other requests now should be much faster
        for image_url in TEST_IMAGE_URLS:
            image_base64 = base64_encoded_image[image_url]
            result = await client.chat.completions.create(
                model=model_name,
                messages=whats_in_this_image_msg(image_base64),
                max_completion_tokens=24,
                temperature=0.0)
            assert result
            choice = result.choices[0]
            # max_completion_tokens caps generation, so the model should
            # stop on length rather than emitting an EOS token.
            assert choice.finish_reason == "length"

            message = choice.message
            assert message.content is not None and len(message.content) >= 10
            assert message.role == "assistant"
0 commit comments