Commit 01fd287

[Model][VLM] Support Bee-8B Model (vllm-project#27012)

Authored and committed by uyzhang, gemini-code-assist[bot], and ywang96.

Signed-off-by: uyzhang <yi.zhang.4096@gmail.com>
Signed-off-by: Yi Zhang <zhangyi970819@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>

1 parent e56a49b · commit 01fd287
File tree: 7 files changed, +228 −0 lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

@@ -634,6 +634,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 |--------------|--------|--------|-------------------|----------------------|---------------------------|
 | `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
 | `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ |
+| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
 | `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
 | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
 | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ |
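With this table entry, Bee-8B is usable like the other generative multimodal models in this list. Below is a minimal single-image sketch against vLLM's offline `LLM` API, assuming this commit is installed; the model name and chat format are taken from the `run_bee` example added later in this commit, while the sample-image helper and sampling settings are illustrative assumptions:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset  # bundled sample image, used here purely for illustration

llm = LLM(
    model="Open-Bee/Bee-8B-RL",
    trust_remote_code=True,
    max_model_len=16384,
    limit_mm_per_prompt={"image": 1},
)

question = "What is shown in this image?"
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    f"<|im_start|>user\n<image>\n{question}<|im_end|>"
    "<|im_start|>assistant\n<think>\n"
)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)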

examples/offline_inference/vision_language.py

Lines changed: 28 additions & 0 deletions

@@ -90,6 +90,33 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Bee-8B
+def run_bee(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+    model_name = "Open-Bee/Bee-8B-RL"
+
+    prompts = [
+        (
+            f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+            f"<|im_start|>user\n<image>\n{question}<|im_end|>"
+            f"<|im_start|>assistant\n<think>\n"
+        )
+        for question in questions
+    ]
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=16384,
+        limit_mm_per_prompt={modality: 1},
+        trust_remote_code=True,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # BLIP-2
 def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

@@ -1708,6 +1735,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
 model_example_map = {
     "aria": run_aria,
     "aya_vision": run_aya_vision,
+    "bee": run_bee,
     "blip-2": run_blip2,
     "chameleon": run_chameleon,
     "dots_ocr": run_dots_ocr,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 36 additions & 0 deletions

@@ -107,6 +107,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_bee(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "Open-Bee/Bee-8B-RL"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=16384,
+        max_num_seqs=16,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        trust_remote_code=True,
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                *placeholders,
+                {"type": "text", "text": question},
+            ],
+        }
+    ]
+
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
+    prompt = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "CohereLabs/command-a-vision-07-2025"

@@ -1215,6 +1250,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
 model_example_map = {
     "aria": load_aria,
     "aya_vision": load_aya_vision,
+    "bee": load_bee,
    "command_a_vision": load_command_a_vision,
    "deepseek_vl_v2": load_deepseek_vl2,
    "gemma3": load_gemma3,

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions

@@ -326,6 +326,7 @@ def _test_processing_correctness_one(
     [
         "rhymes-ai/Aria",
         "CohereForAI/aya-vision-8b",
+        "Open-Bee/Bee-8B-RL",
         "Salesforce/blip2-opt-2.7b",
         "facebook/chameleon-7b",
         "CohereLabs/command-a-vision-07-2025",

tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -566,6 +566,10 @@ def check_available_online(
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),
+    "BeeForConditionalGeneration": _HfExamplesInfo(
+        "Open-Bee/Bee-8B-RL",
+        trust_remote_code=True,
+    ),
     "Blip2ForConditionalGeneration": _HfExamplesInfo(
         "Salesforce/blip2-opt-2.7b",
         extras={"6b": "Salesforce/blip2-opt-6.7b"},

vllm/model_executor/models/bee.py

Lines changed: 157 additions & 0 deletions

@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Mapping
+
+import torch
+import torch.nn as nn
+from transformers.activations import GELUActivation
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalDataDict
+
+from .llava_next import (
+    LlavaDummyInputsBuilder,
+    LlavaNextMultiModalProcessor,
+    LlavaNextProcessingInfo,
+)
+from .llava_onevision import LlavaOnevisionForConditionalGeneration
+from .utils import WeightsMapper
+
+
+class BeeProcessingInfo(LlavaNextProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config()
+
+    def get_hf_processor(self, **kwargs: object):
+        return self.ctx.get_hf_processor(**kwargs)
+
+    def _get_num_unpadded_features(
+        self,
+        *,
+        original_height: int,
+        original_width: int,
+        npatches: int,
+        num_patch_height: int,
+        num_patch_width: int,
+    ) -> tuple[int, int]:
+        """Override to use correct max_num_patches from vision_aspect_ratio."""
+        import math
+
+        current_height = npatches * num_patch_height
+        current_width = npatches * num_patch_width
+
+        aspect_ratio = original_width / original_height
+        current_aspect_ratio = current_width / current_height
+
+        if aspect_ratio > current_aspect_ratio:
+            new_height = int(
+                round(original_height * (current_width / original_width), 7)
+            )
+            padding = (current_height - new_height) // 2
+            current_height = current_height - (2 * padding)
+        else:
+            new_width = int(
+                round(original_width * (current_height / original_height), 7)
+            )
+            padding = (current_width - new_width) // 2
+            current_width = current_width - (2 * padding)
+
+        unpadded_features = current_height * current_width
+        newline_features = current_height
+
+        # Get max_num_patches from vision_aspect_ratio config
+        hf_config = self.get_hf_config()
+        vision_aspect_ratio = getattr(hf_config, "vision_aspect_ratio", "anyres_max_9")
+        max_num_patches = int(vision_aspect_ratio.replace("anyres_max_", ""))
+
+        ratio = math.sqrt(
+            current_height * current_width / (max_num_patches * npatches**2)
+        )
+        if ratio > 1.1:
+            height_factor = int(current_height // ratio)
+            width_factor = int(current_width // ratio)
+            unpadded_features = height_factor * width_factor
+            newline_features = height_factor
+
+        return (unpadded_features, newline_features)
+
+
+class BeeDummyInputsBuilder(LlavaDummyInputsBuilder[BeeProcessingInfo]):
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+        image_token = "<image>"
+
+        return image_token * num_images
+
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
+    ) -> MultiModalDataDict:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = self.info.get_image_size_with_most_features()
+
+        image_overrides = mm_options.get("image") if mm_options else None
+
+        return {
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            ),
+        }
+
+
+class BeeMultiModalProjector(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size,
+            config.text_config.hidden_size * 4,
+            bias=True,
+        )
+        self.act = GELUActivation()
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size * 4,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+
+    def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
+        image_feature = self.pre_norm(image_feature)
+        hidden_states = self.linear_1(image_feature)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+
+        return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaNextMultiModalProcessor,
+    info=BeeProcessingInfo,
+    dummy_inputs=BeeDummyInputsBuilder,
+)
+class BeeForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers
+            # v4.55
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "model.image_newline": "image_newline",
+            "lm_head.": "language_model.lm_head.",
+        }
+    )
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        config = vllm_config.model_config.hf_config
+        self.multi_modal_projector = BeeMultiModalProjector(config)
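To make the projector's role concrete: it layer-normalizes each vision-tower token, expands it from the vision hidden size to four times the text hidden size, applies GELU, and projects down to the text hidden size. A quick shape check using the class added above, assuming this commit is installed (the hidden sizes below are illustrative stand-ins, not the model's actual config values):

from types import SimpleNamespace

import torch

from vllm.model_executor.models.bee import BeeMultiModalProjector

# Illustrative config stub exposing only the fields the projector reads.
config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152),
    text_config=SimpleNamespace(hidden_size=4096),
)
proj = BeeMultiModalProjector(config)

image_features = torch.randn(1, 729, 1152)  # (batch, vision tokens, vision hidden size)
out = proj(image_features)
print(out.shape)  # torch.Size([1, 729, 4096]), i.e. projected into the text hidden size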

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -247,6 +247,7 @@
         "aya_vision",
         "AyaVisionForConditionalGeneration",
     ),
+    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
     "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
     "ChameleonForConditionalGeneration": (
         "chameleon",
