
[V1][Model] Add V1 support for Qwen2-VL #11668

Open · imkero wants to merge 3 commits into main
Conversation

@imkero (Contributor) commented Jan 1, 2025

What's changed:

  1. Allow using a function to determine the dynamic dimensions of a tensor for torch.compile (M-RoPE uses a 2D position tensor, unlike common RoPE, while both share the same implementation in the Qwen2 LM's forward fn); a rough sketch of this idea follows below
  2. Modify dummy data retrieval in profile_run so that Qwen2-VL can launch
  3. Add M-RoPE support to the V1 gpu_model_runner
  4. Add support in gpu_model_runner for encoder outputs returned as a tuple (embeddings: torch.Tensor, modality: str) for Qwen2-VL
  5. Use the token_id instead of the token string for image_token and video_token in Qwen2-VL's preprocessing for better performance

This PR should make Qwen2-VL work in V1 with chunked prefill and prefix caching enabled.
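
For item 1, here is a minimal, hypothetical sketch of the dynamic-dimension idea, assuming positions are (num_tokens,) for standard RoPE and (3, num_tokens) for M-RoPE; the helper names are illustrative and are not the actual vLLM API:

import torch

def resolve_positions_dynamic_dim(positions: torch.Tensor) -> int:
    # Standard RoPE passes positions of shape (num_tokens,); M-RoPE passes
    # (3, num_tokens). The dynamic (token-count) dimension therefore differs.
    return 1 if positions.ndim == 2 else 0

def mark_positions_dynamic(positions: torch.Tensor) -> None:
    # torch._dynamo.mark_dynamic tells torch.compile not to specialize on the
    # size of the given dimension, avoiding recompiles when num_tokens changes.
    torch._dynamo.mark_dynamic(positions, resolve_positions_dynamic_dim(positions))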

imkero added 3 commits January 1, 2025 13:50

github-actions bot commented Jan 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@@ -791,6 +791,7 @@ def _parse_video_data(


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
    _placeholder_map: Optional[dict[str, list[int]]] = None
Member:

I think we should initialize this in the init method to avoid confusing it with a static class variable.

@DarkLight1337 (Member) commented Jan 1, 2025

Apart from this, the processor-related changes in the model file LGTM.

@ywang96 (Member) commented Jan 1, 2025

Hello @imkero! Much appreciated that you made this PR!

The reason I haven't spent too much time on Qwen2-VL is that I want to see if there's a way to move M-RoPE inside the model file for Qwen2-VL, since it is so specific to this model.

You would also need to change the implementation of _process_image_input and _process_video_input for this model to make it work properly on V1 (the returned embeddings need to be a NestedTensor, with the first dimension matching the total number of multimodal data items involved in the batch for fine-grained scheduling).

Feel free to take changes from here into this PR.
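
As a rough illustration of the per-item embeddings point above (a hypothetical helper, not code from this PR): the flat Qwen2-VL vision-encoder output can be split into one tensor per image, since a tuple of tensors is a valid NestedTensors value. The per-image token-count formula below is an assumption based on the image grid sizes and spatial merge size.

import torch

def split_image_embeds_per_item(
    image_embeds: torch.Tensor,    # flat (total_image_tokens, hidden) encoder output
    image_grid_thw: torch.Tensor,  # (num_images, 3) grid sizes per image
    spatial_merge_size: int,
) -> tuple[torch.Tensor, ...]:
    # Assumed sizing: tokens per image = (t * h * w) / spatial_merge_size**2.
    sizes = (image_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
    # One entry per multimodal item, so the scheduler can track items
    # individually for fine-grained scheduling.
    return image_embeds.split(sizes)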

Comment on lines +829 to +838
if not self._placeholder_map:
    # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
    # image_token and video_token registered
    encode_fn = hf_processor.tokenizer.encode
    self._placeholder_map = {
        "image": encode_fn(hf_processor.image_token),
        "video": encode_fn(hf_processor.video_token),
    }
placeholder = self._placeholder_map

Member:

Also, we can set this at initialization time.
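
A minimal sketch of that idea, assuming the HF processor (and its tokenizer) is available when the multimodal processor is constructed; this is not the PR's code, just an illustration of building the map once up front:

def build_placeholder_map(hf_processor) -> dict[str, list[int]]:
    # Encode the image/video placeholder tokens to token ids a single time,
    # so the per-request processing path never re-encodes the strings.
    encode_fn = hf_processor.tokenizer.encode
    return {
        "image": encode_fn(hf_processor.image_token),
        "video": encode_fn(hf_processor.video_token),
    }

The result could then be assigned to self._placeholder_map in __init__ instead of lazily on first use.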

Comment on lines +579 to +582
encoder_outputs.append((
    encoder_output[0][start_idx:end_idx],  # embedding tensor
    encoder_output[1],  # modality
))
@ywang96 (Member) commented Jan 1, 2025

My thought is we don't necessarily need to have the modality key here.

We can leverage the fact that any two mm_positions from any modalities cannot possibly overlap, and now that

def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings`` since we need to
    slice-merge instead of individually scattering.

can apply the embedding replacement based on a list of token ids (so we can simply have [self.config.image_token_id, self.config.video_token_id] here)
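
For illustration only (not part of this PR), the call could then look roughly like this, handling both modalities in one pass:

# Hypothetical call site: the order of placeholders in input_ids must match
# the order of their embeddings in multimodal_embeddings.
inputs_embeds = merge_multimodal_embeddings(
    input_ids,
    inputs_embeds,
    multimodal_embeddings,
    [self.config.image_token_id, self.config.video_token_id],
)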

Therefore, all we need to do should be just sorting the mm_positions and their corresponding mm_inputs in the following code (which also needs to be modified to support the video modality for Qwen2-VL in this PR); a rough sketch follows the snippet.

vllm/vllm/v1/request.py

Lines 51 to 59 in 11d8a09

# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
    # FIXME(woosuk): Support other modalities.
    self.mm_positions = mm_positions.get("image", [])
else:
    self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
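
A rough sketch of that sorting idea (a hypothetical helper, not code from this PR or from vLLM):

def sort_mm_positions(positions_by_modality, inputs_by_modality, start_of):
    # Interleave placeholder ranges from all modalities into a single list
    # ordered by start offset, pairing each range with its mm input.
    # start_of is a callable returning the start offset of a placeholder
    # range; the exact field name is left abstract here on purpose.
    paired = [
        (pos, mm_input)
        for modality, positions in positions_by_modality.items()
        for pos, mm_input in zip(positions, inputs_by_modality[modality])
    ]
    # Ranges from different modalities never overlap, so a single sort
    # yields a well-defined global order.
    paired.sort(key=lambda item: start_of(item[0]))
    return [pos for pos, _ in paired], [mm_input for _, mm_input in paired]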

WDYT?

Member:

On second thought, let me actually work on this design for llava-onevision too.
