Skip to content

Conversation

@bbeckca
Copy link
Contributor

@bbeckca bbeckca commented Jul 28, 2025

Purpose

This PR migrates LlavaImageInputs from a TypedDict-based definition to a structured TensorSchema model with runtime shape validation. This brings it in line with recent changes to Phi3VImagePixelInputs, and is part of a broader effort to improve input contract enforcement and debug-ability across multi-modal models.

Test Plan

Confirm validation works via standalone tests in tests/standalone_test/test_tensor_schema.py and rely on CI to check integration.

Test Result

(venv) benjibeck@Benjis-MBP vllm % python3 -m pytest tests/standalone_tests/test_tensor_schema.py -v --log-cli-level=DEBUG
======================================================================================================================================================================================= test session starts ========================================================================================================================================================================================
platform darwin -- Python 3.9.6, pytest-8.4.1, pluggy-1.6.0 -- /Users/benjibeck/Projects/vllm/venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/benjibeck/Projects/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 14 items                                                                                                                                                                                                                                                                                                                                                                                 

tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_valid_tensor PASSED                                                                                                                                                                                                                                                                                                         [  7%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_optional_fields PASSED                                                                                                                                                                                                                                                                                                      [ 14%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_constant_dim_failure PASSED                                                                                                                                                                                                                                                                                                 [ 21%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_symbolic_dim_mismatch PASSED                                                                                                                                                                                                                                                                                                [ 28%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_list_tensor_valid PASSED                                                                                                                                                                                                                                                                                                    [ 35%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_variable_patch_counts_valid PASSED                                                                                                                                                                                                                                                                                          [ 42%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_tuple_tensor_valid PASSED                                                                                                                                                                                                                                                                                                   [ 50%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_inconsistent_shapes_in_list PASSED                                                                                                                                                                                                                                                                                          [ 57%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_empty_list PASSED                                                                                                                                                                                                                                                                                                           [ 64%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_validation_disabled_skips_shape_check PASSED                                                                                                                                                                                                                                                                                [ 71%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_with_valid_resolve_binding_dims PASSED                                                                                                                                                                                                                                                                                      [ 78%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_with_invalid_resolve_binding_dims PASSED                                                                                                                                                                                                                                                                                    [ 85%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_with_list_of_symbolic_dim PASSED                                                                                                                                                                                                                                                                                            [ 92%]
tests/standalone_tests/test_tensor_schema.py::test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length PASSED                                                                                                                                                                                                                                                                         [100%]

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR migrates LlavaImageInputs to TensorSchema for better input validation. An edge case where an empty list of pixel_values could cause a server crash was identified and should be addressed.

Comment on lines +594 to 589
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When pixel_values is an empty list, flatten_bn(pixel_values, concat=True) on line 597 will raise an exception because it calls torch.cat([]). This can crash the server if a request has an empty list for pixel_values. Add a check for an empty pixel_values list before this block to prevent this.

if pixel_values:
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w
                },
            )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a valid concern, but I'd like to avoid introducing new enforcement with the migration. Happy to update if others feel it'd be helpful.

Comment on lines 58 to 57
Copy link
Contributor Author

@bbeckca bbeckca Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Despite this comment about varying height/width, I'm adding enforcement to match existing behaviors in _validate_pixel_values. Please feel free to correct. @DarkLight1337 @Isotr0py

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

Comment on lines 73 to 72
Copy link
Contributor Author

@bbeckca bbeckca Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No enforcement was previously applied, so I skipped adding validations against (num_channels, height, width). Feel free to let me know if there's other preferences. cc @DarkLight1337 @Isotr0py

@Isotr0py Isotr0py enabled auto-merge (squash) August 5, 2025 16:47
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 5, 2025
auto-merge was automatically disabled August 6, 2025 15:07

Head branch was pushed to by a user without write access

@Isotr0py Isotr0py enabled auto-merge (squash) August 6, 2025 15:14
@DarkLight1337
Copy link
Member

Can you merge from main? It should fix the CI

auto-merge was automatically disabled August 7, 2025 14:30

Head branch was pushed to by a user without write access

@bbeckca
Copy link
Contributor Author

bbeckca commented Aug 8, 2025

Can you merge from main? It should fix the CI

Able to reproduce the main branch CI failures locally. These appear unrelated to this PR. Will rebase once the upstream issue is fixed.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Aug 8, 2025

Retrying MM test

@bbeckca
Copy link
Contributor Author

bbeckca commented Aug 8, 2025

Retrying MM test

Sorry I missed that. Will take a closer look.

@bbeckca
Copy link
Contributor Author

bbeckca commented Aug 9, 2025

Took a closer look, but the MM test failure seems to happen with latest on main. It seems related to downloading image for Pixtral, so unrelated to these changes?

tests/models/multimodal/generation/test_pixtral.py:116: in <module>
    _create_engine_inputs(IMG_URLS),
tests/models/multimodal/generation/test_pixtral.py:76: in _create_engine_inputs
    tokenized = tokenizer.encode_chat_completion(request)
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/mistral.py:379: in encode_chat_completion
    return self.instruct_tokenizer.encode_instruct(instruct_request)
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/instruct.py:179: in encode_instruct
    new_tokens, new_images, new_audios = self.encode_user_message(
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/instruct.py:449: in encode_user_message
    tokens, image, audio = self.encode_user_content(
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/instruct.py:762: in encode_user_content
    chunk_tokens, chunk_image, _ = self._encode_content_chunk(chunk)
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/instruct.py:688: in _encode_content_chunk
    img_encoding = self.image_encoder(chunk)
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/image.py:224: in __call__
    image = image_from_chunk(content)
venv/lib/python3.9/site-packages/mistral_common/tokens/tokenizers/image.py:92: in image_from_chunk
    return download_image(chunk.get_url())
venv/lib/python3.9/site-packages/mistral_common/image.py:33: in download_image
    raise RuntimeError(f"Error downloading the image from {url}: {e}.")
E   RuntimeError: Error downloading the image from https://picsum.photos/id/27/500/500: 525 Server Error: <none> for url: https://picsum.photos/id/27/500/500.

Signed-off-by: Benji Beck <benjibeck@meta.com>
@DarkLight1337
Copy link
Member

Retrying

@vllm-bot vllm-bot merged commit 06da44f into vllm-project:main Aug 11, 2025
35 of 43 checks passed
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Benji Beck <benjibeck@meta.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants