
Conversation

princepride
Contributor

@princepride princepride commented May 31, 2025

Add Tarsier model support: #9707

FIX #9707

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) labels May 31, 2025
@princepride princepride changed the title from "[New model support]add tarsier model support" to "Add tarsier model support" on May 31, 2025
@DarkLight1337
Member

PTAL at the failing tests

@princepride princepride requested a review from DarkLight1337 June 1, 2025 23:28
@princepride
Contributor Author

@DarkLight1337 It seems that some other model has an error.

@DarkLight1337
Member

cc @Isotr0py do you have time to help validate this model?

@princepride
Contributor Author

from pathlib import Path
from vllm import LLM, SamplingParams
from PIL import Image
from vllm.multimodal.video import VideoMediaIO, ImageMediaIO

def extract_frames_from_video(video_filepath: str, num_frames: int):
    image_io = ImageMediaIO()
    video_io = VideoMediaIO(image_io=image_io, num_frames=num_frames)
    frames = video_io.load_file(Path(video_filepath))
    return frames

if __name__ == "__main__":
    EXAMPLE_IMAGE_PATH = "kitty.jpg"
    EXAMPLE_VIDEO_PATH = "kitchen.mp4"
    MAX_VIDEO_FRAMES = 4
    llm = LLM(model="omni-research/Tarsier-7b", trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.1, top_p=0.9, max_tokens=500)

    # Scenario 1: Pure text test
    print(f"\n--- Pure Text Test ---")
    vllm_inputs_text_only = {"prompt": "USER: Please introduce yourself. ASSISTANT:"}
    outputs = llm.generate(vllm_inputs_text_only, sampling_params)
    for output_item in outputs:
        print(f"Generated: {output_item.outputs[0].text}\n" + "-" * 20)

    # Scenario 2: Text and single image test
    print(f"\n--- Text and Single Image Test ---")
    vllm_inputs_single_image = {
        "prompt": "USER: <image>\nPlease describe the image. ASSISTANT:",
        "multi_modal_data": {"image": [Image.open(EXAMPLE_IMAGE_PATH).convert('RGB')]}
    }
    outputs = llm.generate(vllm_inputs_single_image, sampling_params) # Direct generation
    for output_item in outputs:
        print(f"Generated: {output_item.outputs[0].text}\n" + "-" * 20)

    # Scenario 3: Text and video (multiple frames) test
    print(f"\n--- Text and video (multiple frames) test ---")
    vllm_inputs_video = {
        "prompt": f"USER: {'<image>'*MAX_VIDEO_FRAMES}\nPlease describe the video. ASSISTANT:",
        "multi_modal_data": {"image": extract_frames_from_video(EXAMPLE_VIDEO_PATH, MAX_VIDEO_FRAMES)}
    }
    outputs = llm.generate(vllm_inputs_video, sampling_params) # Direct generation
    for output_item in outputs:
        print(f"Generated: {output_item.outputs[0].text}\n" + "-" * 20)

    print("\nAll tests completed.")

Here is my simple test code; you can refer to it.

@Isotr0py
Member

Isotr0py commented Jun 2, 2025

--- Pure Text Test ---
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 195.30it/s]
Processed prompts: 100%|█████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.05s/it, est. speed input: 12.42 toks/s, output: 23.88 toks/s]
Generated:  I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).
--------------------

--- Text and Single Image Test ---
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.07s/it]
Processed prompts: 100%|█████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.95s/it, est. speed input: 69.08 toks/s, output: 27.50 toks/s]
Generated: The image captures a vibrant street scene in Chinatown, Melbourne, Australia. Dominating the foreground is a red octagonal stop sign, standing resolute on the sidewalk. It's a familiar sight, a universal symbol instructing drivers to halt their vehicles. Just beyond the sign, the street unfolds in a lively display of culture and commerce. A red lantern hangs from the side of a building, its color matching the stop sign and adding to the festive atmosphere. The building itself is a mix of traditional and modern architecture, with a green awning providing a pop of color against the urban landscape. The street is bustling with activity. People are seen walking on the sidewalk, adding a dynamic element to the scene. Cars are parked along the street, their metallic bodies gleaming under the sunlight. Above it all, the sky stretches out in a clear blue expanse, dotted here and there with trees that add a touch of nature to the urban setting. The image is a snapshot of everyday life in Melbourne, capturing the city's vibrant street scenes and multicultural atmosphere.
--------------------

--- Text and video (multiple frames) test ---
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.73it/s]
Processed prompts: 100%|████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.89s/it, est. speed input: 837.87 toks/s, output: 18.00 toks/s]
Generated:  A young child is sitting on a bed, holding and interacting with a book. The child flips through the pages of the book, occasionally looking down at it. The background shows a bed with a blanket and some clothes scattered on it.
--------------------

According to the model outputs, the model implementation should be fine.

Member

@Isotr0py Isotr0py left a comment


Can you also update the supported_models documentation and vllm/entrypoints/chat_utils.py?

def _placeholder_str(self, modality: ModalityStr,
                     current_count: int) -> Optional[str]:
    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)

… entrypoints placeholder
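
For reference, here is a minimal, self-contained sketch (not the actual vLLM code) of what the requested chat_utils.py change amounts to: each multimodal model type maps to the placeholder token that stands in for an attachment in the prompt, and the assumption here is that Tarsier reuses the LLaVA-style "<image>" placeholder, as the test script above does.

from typing import Optional

# Hypothetical lookup table; the "tarsier" entry is the kind of addition this review asks for.
_IMAGE_PLACEHOLDERS = {
    "llava": "<image>",
    "tarsier": "<image>",  # assumption: Tarsier follows the LLaVA convention
}

def placeholder_str(model_type: str, modality: str) -> Optional[str]:
    # Return the prompt placeholder for one multimodal item, or None if unsupported.
    if modality == "image":
        return _IMAGE_PLACEHOLDERS.get(model_type)
    return None

print(placeholder_str("tarsier", "image"))  # -> <image>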

@princepride princepride requested a review from hmellor as a code owner June 2, 2025 09:47
@mergify mergify bot added the frontend label Jun 2, 2025
@princepride princepride requested a review from Isotr0py June 2, 2025 12:38
Member

@Isotr0py Isotr0py left a comment


Processor tests passed on my side locally as well. So this PR should be good to go!

@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 2, 2025
@princepride princepride requested a review from Isotr0py June 2, 2025 22:40
@princepride
Contributor Author

@Isotr0py Sorry to bother you with another review. This morning I noticed that all 75 checks had passed, but an "Update branch" button appeared. I clicked it and the whole CI process restarted. Please forgive me for being a rookie: should I not click the "Update branch" button after all checks have passed?

@Isotr0py
Member

Isotr0py commented Jun 3, 2025

Hmmm, I remember we disabled this button before... No need to update the branch; I think this restriction will be disabled again, and then we can merge this directly.

@princepride
Contributor Author

@DarkLight1337 Can you merge it? Thank you.

@Isotr0py Isotr0py merged commit 1282bd8 into vllm-project:main Jun 3, 2025
67 checks passed
@DarkLight1337 DarkLight1337 mentioned this pull request Jun 3, 2025
1 task
@patil-suraj

Does this support https://huggingface.co/omni-research/Tarsier2-Recap-7b as well?

@princepride
Contributor Author

Does this support https://huggingface.co/omni-research/Tarsier2-Recap-7b as well?

No, because they have different architectures. I read their paper: Tarsier2 is fine-tuned from Qwen2-VL. Have you tried using Qwen2-VL to load Tarsier2? If that doesn't work, I can try to add support for the Tarsier2 model.

@princepride
Contributor Author

You can pass hf_overrides={"architectures": ["Qwen2VLForConditionalGeneration"]} to try it.
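
As a hedged sketch of that suggestion (assuming the checkpoint's processor config is compatible, which the follow-up below shows is not guaranteed), the override is passed when constructing the engine:

from vllm import LLM

# Load the Tarsier2 checkpoint through the existing Qwen2-VL implementation by
# overriding the architecture reported in its config.
llm = LLM(
    model="omni-research/Tarsier2-Recap-7b",
    trust_remote_code=True,
    hf_overrides={"architectures": ["Qwen2VLForConditionalGeneration"]},
)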

@princepride
Contributor Author

@patil-suraj were you able to run Tarsier2-Recap-7b with Qwen2VLForConditionalGeneration? I got an error:

ERROR 06-17 06:56:48 [core.py:520] EngineCore failed to start.
ERROR 06-17 06:56:48 [core.py:520] Traceback (most recent call last):
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/engine/core.py", line 511, in run_engine_core
ERROR 06-17 06:56:48 [core.py:520]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/engine/core.py", line 395, in __init__
ERROR 06-17 06:56:48 [core.py:520]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/engine/core.py", line 76, in __init__
ERROR 06-17 06:56:48 [core.py:520]     self.model_executor = executor_class(vllm_config)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/executor/executor_base.py", line 53, in __init__
ERROR 06-17 06:56:48 [core.py:520]     self._init_executor()
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 47, in _init_executor
ERROR 06-17 06:56:48 [core.py:520]     self.collective_rpc("init_device")
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-17 06:56:48 [core.py:520]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/utils.py", line 2690, in run_method
ERROR 06-17 06:56:48 [core.py:520]     return func(*args, **kwargs)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/worker/worker_base.py", line 606, in init_device
ERROR 06-17 06:56:48 [core.py:520]     self.worker.init_device()  # type: ignore
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 165, in init_device
ERROR 06-17 06:56:48 [core.py:520]     self.model_runner: GPUModelRunner = GPUModelRunner(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/worker/gpu_model_runner.py", line 134, in __init__
ERROR 06-17 06:56:48 [core.py:520]     encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 95, in compute_encoder_budget
ERROR 06-17 06:56:48 [core.py:520]     ) = _compute_encoder_budget_multimodal(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 125, in _compute_encoder_budget_multimodal
ERROR 06-17 06:56:48 [core.py:520]     .get_max_tokens_per_item_by_nonzero_modality(model_config)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/multimodal/registry.py", line 158, in get_max_tokens_per_item_by_nonzero_modality
ERROR 06-17 06:56:48 [core.py:520]     self.get_max_tokens_per_item_by_modality(model_config).items()
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/multimodal/registry.py", line 132, in get_max_tokens_per_item_by_modality
ERROR 06-17 06:56:48 [core.py:520]     return profiler.get_mm_max_tokens(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/multimodal/profiling.py", line 256, in get_mm_max_tokens
ERROR 06-17 06:56:48 [core.py:520]     mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/multimodal/profiling.py", line 166, in _get_dummy_mm_inputs
ERROR 06-17 06:56:48 [core.py:520]     processor_inputs = factory.get_dummy_processor_inputs(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/multimodal/profiling.py", line 91, in get_dummy_processor_inputs
ERROR 06-17 06:56:48 [core.py:520]     dummy_text = self.get_dummy_text(mm_counts)
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 973, in get_dummy_text
ERROR 06-17 06:56:48 [core.py:520]     hf_processor = self.info.get_hf_processor()
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 762, in get_hf_processor
ERROR 06-17 06:56:48 [core.py:520]     image_processor=self.get_image_processor(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 811, in get_image_processor
ERROR 06-17 06:56:48 [core.py:520]     return cached_image_processor_from_config(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/transformers_utils/processor.py", line 216, in cached_image_processor_from_config
ERROR 06-17 06:56:48 [core.py:520]     return cached_get_image_processor(
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/transformers_utils/processor.py", line 204, in get_image_processor
ERROR 06-17 06:56:48 [core.py:520]     raise e
ERROR 06-17 06:56:48 [core.py:520]   File "/workspace/vllm/vllm/transformers_utils/processor.py", line 185, in get_image_processor
ERROR 06-17 06:56:48 [core.py:520]     processor = AutoImageProcessor.from_pretrained(
ERROR 06-17 06:56:48 [core.py:520]   File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/image_processing_auto.py", line 564, in from_pretrained
ERROR 06-17 06:56:48 [core.py:520]     return image_processor_class.from_dict(config_dict, **kwargs)
ERROR 06-17 06:56:48 [core.py:520]   File "/usr/local/lib/python3.10/dist-packages/transformers/image_processing_base.py", line 422, in from_dict
ERROR 06-17 06:56:48 [core.py:520]     image_processor = cls(**image_processor_dict)
ERROR 06-17 06:56:48 [core.py:520]   File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 143, in __init__
ERROR 06-17 06:56:48 [core.py:520]     raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
ERROR 06-17 06:56:48 [core.py:520] ValueError: size must contain 'shortest_edge' and 'longest_edge' keys.
Process EngineCore_0:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 524, in run_engine_core
    raise e
  File "/workspace/vllm/vllm/v1/engine/core.py", line 511, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/workspace/vllm/vllm/v1/engine/core.py", line 395, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/workspace/vllm/vllm/v1/engine/core.py", line 76, in __init__
    self.model_executor = executor_class(vllm_config)
  File "/workspace/vllm/vllm/executor/executor_base.py", line 53, in __init__
    self._init_executor()
  File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 47, in _init_executor
    self.collective_rpc("init_device")
  File "/workspace/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/workspace/vllm/vllm/utils.py", line 2690, in run_method
    return func(*args, **kwargs)
  File "/workspace/vllm/vllm/worker/worker_base.py", line 606, in init_device
    self.worker.init_device()  # type: ignore
  File "/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 165, in init_device
    self.model_runner: GPUModelRunner = GPUModelRunner(
  File "/workspace/vllm/vllm/v1/worker/gpu_model_runner.py", line 134, in __init__
    encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
  File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 95, in compute_encoder_budget
    ) = _compute_encoder_budget_multimodal(
  File "/workspace/vllm/vllm/v1/core/encoder_cache_manager.py", line 125, in _compute_encoder_budget_multimodal
    .get_max_tokens_per_item_by_nonzero_modality(model_config)
  File "/workspace/vllm/vllm/multimodal/registry.py", line 158, in get_max_tokens_per_item_by_nonzero_modality
    self.get_max_tokens_per_item_by_modality(model_config).items()
  File "/workspace/vllm/vllm/multimodal/registry.py", line 132, in get_max_tokens_per_item_by_modality
    return profiler.get_mm_max_tokens(
  File "/workspace/vllm/vllm/multimodal/profiling.py", line 256, in get_mm_max_tokens
    mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
  File "/workspace/vllm/vllm/multimodal/profiling.py", line 166, in _get_dummy_mm_inputs
    processor_inputs = factory.get_dummy_processor_inputs(
  File "/workspace/vllm/vllm/multimodal/profiling.py", line 91, in get_dummy_processor_inputs
    dummy_text = self.get_dummy_text(mm_counts)
  File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 973, in get_dummy_text
    hf_processor = self.info.get_hf_processor()
  File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 762, in get_hf_processor
    image_processor=self.get_image_processor(
  File "/workspace/vllm/vllm/model_executor/models/qwen2_vl.py", line 811, in get_image_processor
    return cached_image_processor_from_config(
  File "/workspace/vllm/vllm/transformers_utils/processor.py", line 216, in cached_image_processor_from_config
    return cached_get_image_processor(
  File "/workspace/vllm/vllm/transformers_utils/processor.py", line 204, in get_image_processor
    raise e
  File "/workspace/vllm/vllm/transformers_utils/processor.py", line 185, in get_image_processor
    processor = AutoImageProcessor.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/image_processing_auto.py", line 564, in from_pretrained
    return image_processor_class.from_dict(config_dict, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/image_processing_base.py", line 422, in from_dict
    image_processor = cls(**image_processor_dict)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 143, in __init__
    raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
ValueError: size must contain 'shortest_edge' and 'longest_edge' keys.
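
To illustrate the ValueError at the end of this trace (a hedged reading, not verified against the actual Tarsier2 checkpoint): Qwen2VLImageProcessor.__init__ expects pixel bounds under "shortest_edge"/"longest_edge" in the size dict, whereas a LLaVA-style preprocessor config typically provides height/width instead. The values below are placeholders for illustration only.

# Shape of the size dict Qwen2VLImageProcessor validates for (illustrative values):
qwen2_vl_style_size = {"shortest_edge": 3136, "longest_edge": 1003520}

# Shape of the size dict a LLaVA-style image processor config usually ships:
llava_style_size = {"height": 336, "width": 336}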

@patil-suraj

@princepride Thank you!
I tried to load it using transformers directly and I can load the weights with Qwen2VLForConditionalGeneration. Haven't checked inference yet.

from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained("./Tarsier2-Recap-7b/")

The error you have seems to be coming from the processor which probably has different/missing keys.

@Isotr0py
Member

@princepride It seems Tarsier2 uses a different processor instead of the Qwen2-VL processor: https://github.com/bytedance/tarsier/blob/4581ecf213596f50c9b38fff753facf07f198b94/dataset/tarsier_processor.py#L24-L29

Perhaps we need to register a different multimodal processor for it just like Mantis.

@patil-suraj

That's correct; the inference code actually uses this:
https://github.com/bytedance/tarsier/blob/tarsier2/dataset/tarsier_datamodule.py#L224

@princepride
Contributor Author

@princepride It seems Tarsier2 uses a different processor instead of the Qwen2-VL processor: https://github.com/bytedance/tarsier/blob/4581ecf213596f50c9b38fff753facf07f198b94/dataset/tarsier_processor.py#L24-L29

Perhaps we need to register a different multimodal processor for it just like Mantis.

👌, I will handle it

@princepride
Contributor Author

@Isotr0py The Tarsier2 model's config file specifies the model_type as "llava". This causes vllm_config.model_config.hf_config to be automatically created as a LlavaConfig object. How should I resolve this?

@DarkLight1337
Member

You should also override it using hf_overrides
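
A hedged sketch of that suggestion (the exact override needed depends on how vLLM resolves the config, and the follow-up below notes it did not take effect out of the box):

from vllm import LLM

# Assumed approach: override both the "llava" model_type and the architecture so the
# checkpoint resolves to Qwen2-VL. Treat this as a sketch, not the final fix used here.
llm = LLM(
    model="omni-research/Tarsier2-Recap-7b",
    trust_remote_code=True,
    hf_overrides={
        "model_type": "qwen2_vl",
        "architectures": ["Qwen2VLForConditionalGeneration"],
    },
)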

@princepride
Contributor Author

It doesn't seem to work; the model still automatically creates a LlavaConfig object. Can I manually map the config values and create a Qwen2VLConfig instead?

@DarkLight1337
Member

What do you mean by manual mapping?

@princepride
Contributor Author

princepride commented Jun 19, 2025

Sorry, the problem has been solved.
